Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REVIEW] blocks_per_sm FIL parameter in Python. #3180

Merged
merged 5 commits into from
Nov 24, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
- PR #3137: Reorganize Pytest Config and Add Quick Run Option
- PR #3144: Adding Ability to Set Arbitrary Cmake Flags in ./build.sh
- PR #3155: Eliminate unnecessary warnings from random projection test
- PR #3180: FIL: `blocks_per_sm` support in Python

## Bug Fixes
- PR #3069: Prevent conversion of DataFrames to Series in preprocessing
Expand Down
65 changes: 56 additions & 9 deletions python/cuml/fil/fil.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -535,9 +535,9 @@ class ForestInference(Base):
def load_from_treelite_model(self, model, output_class=False,
algo='auto',
threshold=0.5,
storage_type='auto'):
"""
Creates a FIL model using the treelite model
storage_type='auto',
blocks_per_sm=0):
"""Creates a FIL model using the treelite model
passed to the function.

Parameters
Expand Down Expand Up @@ -574,29 +574,42 @@ class ForestInference(Base):
can fail if 8-byte nodes are not enough
to store the forest, e.g. if there are
too many nodes in a tree or too many features
blocks_per_sm : integer (default=0)
(experimental) Indicates how the number of thread blocks to launch
for the inference kernel is determined.
- 0 (default) - launches the number of blocks proportional to the
number of data rows;
- >= 1 - attempts to launch blocks_per_sm blocks per SM. This will
fail if blocks_per_sm blocks result in more threads than the
maximum supported number of threads per GPU. Even if successful,
it is not guaranteed that blocks_per_sm blocks will run on an SM
concurrently.

Returns
----------
fil_model
A Forest Inference model which can be used to perform
inferencing on the random forest/ XGBoost model.

"""
if isinstance(model, TreeliteModel):
# TreeliteModel defined in this file
return self._impl.load_from_treelite_model(
model, output_class, algo, threshold, str(storage_type), 0)
model, output_class, algo, threshold, str(storage_type),
blocks_per_sm)
else:
# assume it is treelite.Model
return self._impl.load_from_treelite_model_handle(
model.handle.value, output_class, algo, threshold,
str(storage_type), 0)
str(storage_type), blocks_per_sm)

@staticmethod
def load_from_sklearn(skl_model,
output_class=False,
threshold=0.50,
algo='auto',
storage_type='auto',
blocks_per_sm=0,
handle=None):
"""
Creates a FIL model using the scikit-learn model passed to the
Expand Down Expand Up @@ -629,6 +642,16 @@ class ForestInference(Base):
- False - create a dense forest
- True - create a sparse forest;
requires algo='NAIVE' or algo='AUTO'
blocks_per_sm : integer (default=0)
(experimental) Indicates how the number of thread blocks to launch
for the inference kernel is determined.
- 0 (default) - launches the number of blocks proportional to the
number of data rows;
- >= 1 - attempts to launch blocks_per_sm blocks per SM. This will
fail if blocks_per_sm blocks result in more threads than the
maximum supported number of threads per GPU. Even if successful,
it is not guaranteed that blocks_per_sm blocks will run on an SM
concurrently.

Returns
----------
Expand All @@ -644,7 +667,8 @@ class ForestInference(Base):
tl_model = tl_skl.import_model(skl_model)
cuml_fm.load_from_treelite_model(
tl_model, algo=algo, output_class=output_class,
storage_type=str(storage_type), threshold=threshold)
storage_type=str(storage_type), threshold=threshold,
blocks_per_sm=blocks_per_sm)
return cuml_fm

@staticmethod
Expand All @@ -653,6 +677,7 @@ class ForestInference(Base):
threshold=0.50,
algo='auto',
storage_type='auto',
blocks_per_sm=0,
model_type="xgboost",
handle=None):
"""
Expand All @@ -677,6 +702,16 @@ class ForestInference(Base):
storage_type : string (default='auto')
In-memory storage format to be used for the FIL model.
See documentation in `FIL.load_from_treelite_model`
blocks_per_sm : integer (default=0)
(experimental) Indicates how the number of thread blocks to launch
for the inference kernel is determined.
- 0 (default) - launches the number of blocks proportional to the
number of data rows;
- >= 1 - attempts to launch blocks_per_sm blocks per SM. This will
fail if blocks_per_sm blocks result in more threads than the
maximum supported number of threads per GPU. Even if successful,
it is not guaranteed that blocks_per_sm blocks will run on an SM
concurrently.
model_type : string (default="xgboost")
Format of the saved treelite model to be loaded.
It can be 'xgboost' or 'lightgbm'.
Expand All @@ -694,15 +729,17 @@ class ForestInference(Base):
algo=algo,
output_class=output_class,
storage_type=str(storage_type),
threshold=threshold)
threshold=threshold,
blocks_per_sm=blocks_per_sm)
return cuml_fm

def load_using_treelite_handle(self,
model_handle,
output_class=False,
algo='auto',
storage_type='auto',
threshold=0.50):
threshold=0.50,
blocks_per_sm=0):
"""
Returns a FIL instance by converting a treelite model to
FIL model by using the treelite ModelHandle passed.
Expand All @@ -724,6 +761,16 @@ class ForestInference(Base):
storage_type : string (default='auto')
In-memory storage format to be used for the FIL model.
See documentation in `FIL.load_from_treelite_model`
blocks_per_sm : integer (default=0)
(experimental) Indicates how the number of thread blocks to launch
for the inference kernel is determined.
- 0 (default) - launches the number of blocks proportional to the
number of data rows;
- >= 1 - attempts to launch blocks_per_sm blocks per SM. This will
fail if blocks_per_sm blocks result in more threads than the
maximum supported number of threads per GPU. Even if successful,
it is not guaranteed that blocks_per_sm blocks will run on an SM
concurrently.

Returns
----------
Expand All @@ -735,7 +782,7 @@ class ForestInference(Base):
output_class,
algo, threshold,
str(storage_type),
0)
blocks_per_sm)
# DO NOT RETURN self._impl here!!
return self

Expand Down
20 changes: 20 additions & 0 deletions python/cuml/test/test_fil.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,26 @@ def test_output_storage_type(storage_type, small_classifier_and_preds):
assert np.allclose(fil_preds, xgb_preds_int, 1e-3)


@pytest.mark.skipif(has_xgboost() is False, reason="need to install xgboost")
@pytest.mark.parametrize('storage_type',
                         ['dense', 'sparse', 'sparse8'])
@pytest.mark.parametrize('blocks_per_sm', [0, 1, 2, 3, 4])
def test_output_blocks_per_sm(storage_type, blocks_per_sm,
                              small_classifier_and_preds):
    """Verify FIL class predictions match (rounded) XGBoost predictions
    for every supported storage type across a range of blocks_per_sm
    values (0 = launch blocks proportional to the number of rows).
    """
    model_path, X, xgb_preds = small_classifier_and_preds
    fm = ForestInference.load(model_path,
                              output_class=True,
                              storage_type=storage_type,
                              threshold=0.50,
                              blocks_per_sm=blocks_per_sm)

    # Round XGBoost probabilities to hard class labels for comparison.
    xgb_preds_int = np.around(xgb_preds)
    fil_preds = np.asarray(fm.predict(X))
    fil_preds = np.reshape(fil_preds, np.shape(xgb_preds_int))

    # Class labels should agree; the tolerance only absorbs float noise.
    assert np.allclose(fil_preds, xgb_preds_int, 1e-3)


@pytest.mark.parametrize('output_class', [True, False])
@pytest.mark.skipif(has_xgboost() is False, reason="need to install xgboost")
def test_thresholding(output_class, small_classifier_and_preds):
Expand Down