diff --git a/CHANGELOG.md b/CHANGELOG.md index 31d84fafa1..173548c2d8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -42,6 +42,7 @@ - PR #3144: Adding Ability to Set Arbitrary Cmake Flags in ./build.sh - PR #3155: Eliminate unnecessary warnings from random projection test - PR #3176: Add probabilistic SVM tests with various input array types +- PR #3180: FIL: `blocks_per_sm` support in Python ## Bug Fixes - PR #3179: Remove unused metrics.cu file diff --git a/python/cuml/fil/fil.pyx b/python/cuml/fil/fil.pyx index c255c5e4fc..d85afa8c51 100644 --- a/python/cuml/fil/fil.pyx +++ b/python/cuml/fil/fil.pyx @@ -535,9 +535,9 @@ class ForestInference(Base): def load_from_treelite_model(self, model, output_class=False, algo='auto', threshold=0.5, - storage_type='auto'): - """ - Creates a FIL model using the treelite model + storage_type='auto', + blocks_per_sm=0): + """Creates a FIL model using the treelite model passed to the function. Parameters ---------- @@ -574,22 +574,34 @@ class ForestInference(Base): can fail if 8-byte nodes are not enough to store the forest, e.g. if there are too many nodes in a tree or too many features + blocks_per_sm : integer (default=0) + (experimental) Indicates how the number of thread blocks to launch + for the inference kernel is determined. + - 0 (default) - launches the number of blocks proportional to the + number of data rows; + - >= 1 - attempts to launch blocks_per_sm blocks per SM. This will + fail if blocks_per_sm blocks result in more threads than the + maximum supported number of threads per GPU. Even if successful, + it is not guaranteed that blocks_per_sm blocks will run on an SM + concurrently. Returns ---------- fil_model A Forest Inference model which can be used to perform inferencing on the random forest/ XGBoost model. 
+ """ if isinstance(model, TreeliteModel): # TreeliteModel defined in this file return self._impl.load_from_treelite_model( - model, output_class, algo, threshold, str(storage_type), 0) + model, output_class, algo, threshold, str(storage_type), + blocks_per_sm) else: # assume it is treelite.Model return self._impl.load_from_treelite_model_handle( model.handle.value, output_class, algo, threshold, - str(storage_type), 0) + str(storage_type), blocks_per_sm) @staticmethod def load_from_sklearn(skl_model, @@ -597,6 +609,7 @@ class ForestInference(Base): threshold=0.50, algo='auto', storage_type='auto', + blocks_per_sm=0, handle=None): """ Creates a FIL model using the scikit-learn model passed to the @@ -629,6 +642,16 @@ class ForestInference(Base): - False - create a dense forest - True - create a sparse forest; requires algo='NAIVE' or algo='AUTO' + blocks_per_sm : integer (default=0) + (experimental) Indicates how the number of thread blocks to lauch + for the inference kernel is determined. + - 0 (default) - launches the number of blocks proportional to the + number of data rows; + - >= 1 - attempts to lauch blocks_per_sm blocks per SM. This will + fail if blocks_per_sm blocks result in more threads than the + maximum supported number of threads per GPU. Even if successful, + it is not guaranteed that blocks_per_sm blocks will run on an SM + concurrently. 
Returns ---------- @@ -644,7 +667,8 @@ class ForestInference(Base): tl_model = tl_skl.import_model(skl_model) cuml_fm.load_from_treelite_model( tl_model, algo=algo, output_class=output_class, - storage_type=str(storage_type), threshold=threshold) + storage_type=str(storage_type), threshold=threshold, + blocks_per_sm=blocks_per_sm) return cuml_fm @staticmethod @@ -653,6 +677,7 @@ class ForestInference(Base): threshold=0.50, algo='auto', storage_type='auto', + blocks_per_sm=0, model_type="xgboost", handle=None): """ @@ -677,6 +702,16 @@ class ForestInference(Base): storage_type : string (default='auto') In-memory storage format to be used for the FIL model. See documentation in `FIL.load_from_treelite_model` + blocks_per_sm : integer (default=0) + (experimental) Indicates how the number of thread blocks to launch + for the inference kernel is determined. + - 0 (default) - launches the number of blocks proportional to the + number of data rows; + - >= 1 - attempts to launch blocks_per_sm blocks per SM. This will + fail if blocks_per_sm blocks result in more threads than the + maximum supported number of threads per GPU. Even if successful, + it is not guaranteed that blocks_per_sm blocks will run on an SM + concurrently. model_type : string (default="xgboost") Format of the saved treelite model to be load. It can be 'xgboost', 'lightgbm'. @@ -694,7 +729,8 @@ class ForestInference(Base): algo=algo, output_class=output_class, storage_type=str(storage_type), - threshold=threshold) + threshold=threshold, + blocks_per_sm=blocks_per_sm) return cuml_fm def load_using_treelite_handle(self, @@ -702,7 +738,8 @@ class ForestInference(Base): output_class=False, algo='auto', storage_type='auto', - threshold=0.50): + threshold=0.50, + blocks_per_sm=0): """ Returns a FIL instance by converting a treelite model to FIL model by using the treelite ModelHandle passed. 
@@ -724,6 +761,16 @@ class ForestInference(Base): storage_type : string (default='auto') In-memory storage format to be used for the FIL model. See documentation in `FIL.load_from_treelite_model` + blocks_per_sm : integer (default=0) + (experimental) Indicates how the number of thread blocks to launch + for the inference kernel is determined. + - 0 (default) - launches the number of blocks proportional to the + number of data rows; + - >= 1 - attempts to launch blocks_per_sm blocks per SM. This will + fail if blocks_per_sm blocks result in more threads than the + maximum supported number of threads per GPU. Even if successful, + it is not guaranteed that blocks_per_sm blocks will run on an SM + concurrently. Returns ---------- @@ -735,7 +782,7 @@ class ForestInference(Base): output_class, algo, threshold, str(storage_type), - 0) + blocks_per_sm) # DO NOT RETURN self._impl here!! return self diff --git a/python/cuml/test/test_fil.py b/python/cuml/test/test_fil.py index 5fb968b10a..cf0d024933 100644 --- a/python/cuml/test/test_fil.py +++ b/python/cuml/test/test_fil.py @@ -376,6 +376,25 @@ def test_output_storage_type(storage_type, small_classifier_and_preds): assert np.allclose(fil_preds, xgb_preds_int, 1e-3) +@pytest.mark.skipif(has_xgboost() is False, reason="need to install xgboost") +@pytest.mark.parametrize('storage_type', ['dense', 'sparse']) +@pytest.mark.parametrize('blocks_per_sm', [1, 2, 3, 4]) +def test_output_blocks_per_sm(storage_type, blocks_per_sm, + small_classifier_and_preds): + model_path, X, xgb_preds = small_classifier_and_preds + fm = ForestInference.load(model_path, + output_class=True, + storage_type=storage_type, + threshold=0.50, + blocks_per_sm=blocks_per_sm) + + xgb_preds_int = np.around(xgb_preds) + fil_preds = np.asarray(fm.predict(X)) + fil_preds = np.reshape(fil_preds, np.shape(xgb_preds_int)) + + assert np.allclose(fil_preds, xgb_preds_int, 1e-3) + + @pytest.mark.parametrize('output_class', [True, False]) 
@pytest.mark.skipif(has_xgboost() is False, reason="need to install xgboost") def test_thresholding(output_class, small_classifier_and_preds):