CatBoost converter #392
@@ -44,6 +44,29 @@ def convert_libsvm(model, name=None, initial_types=None, doc_string='', target_o
                   custom_conversion_functions, custom_shape_calculators)


def convert_catboost(model, name=None, initial_types=None, doc_string='', target_opset=None,
                     targeted_onnx=onnx.__version__, custom_conversion_functions=None, custom_shape_calculators=None):
    try:
        from catboost.utils import convert_to_onnx_object
    except ImportError:
        raise RuntimeError('CatBoost is not installed or need to be updated. '
                           'Please install/upgrade CatBoost to use this feature.')

Reviewer (on the error message): nit: "needs to be updated."
    if custom_conversion_functions:
        warnings.warn('custom_conversion_functions is not supported. Please set it to None.')

    if custom_shape_calculators:
        warnings.warn('custom_shape_calculators is not supported. Please set it to None.')

Reviewer: Why include these converter arguments if they are not supported? It might be better to remove the arguments entirely. In the code above for the keras converter, these arguments were deprecated, which is why the warning messages were necessary.

Author: I thought all converters had pretty much the same interface, which is why I added the arguments. I have discussed the matter with a member of the CatBoost team; I will create a PR that changes the CatBoost converter interface to pass these arguments through to CatBoost's side, and the CatBoost team may implement the functionality in the future. I will update my PR when that change is released, if that is OK.

Reviewer: The signature in onnxmltools is not always the same, only in sklearn-onnx. So I would either remove the parameter or raise an exception if the parameter is not None.

Author: Thanks, I removed the arguments that are not supported.
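A minimal sketch of the reviewer's "raise instead of warn" alternative for the unsupported arguments (illustrative only, not the code this PR ends up with; the helper name `_check_unsupported_arguments` is hypothetical):

```python
# Illustrative alternative to the warnings: reject unsupported arguments outright.
# Meant as a fragment for the body of convert_catboost, not a standalone script.
def _check_unsupported_arguments(custom_conversion_functions, custom_shape_calculators):
    if custom_conversion_functions is not None:
        raise NotImplementedError(
            'custom_conversion_functions is not supported by the CatBoost converter.')
    if custom_shape_calculators is not None:
        raise NotImplementedError(
            'custom_shape_calculators is not supported by the CatBoost converter.')
```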
    export_parameters = {
        'onnx_domain': 'ai.catboost',
        'onnx_model_version': 0,
        'onnx_doc_string': doc_string,
        'onnx_graph_name': name
    }

    return convert_to_onnx_object(model, export_parameters=export_parameters)

Reviewer (on 'onnx_domain': 'ai.catboost'): Why not use existing domains?

Author: Thanks! You are right, I will change it to ai.onnx.

Author: Removed.

Reviewer (on the return statement): We need to take care of the target_opset argument, which specifies the opset version that will be used in the generated ONNX model.

Author: Thanks, I now pass target_opset to CatBoost and check it there.


def convert_lightgbm(model, name=None, initial_types=None, doc_string='', target_opset=None,
                     targeted_onnx=onnx.__version__, custom_conversion_functions=None, custom_shape_calculators=None):
    if not utils.lightgbm_installed():
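For context, here is a minimal usage sketch of the new entry point (a sketch, assuming catboost and onnxmltools are installed; the model settings and output file name are arbitrary, and save_model is the existing onnxmltools helper):

```python
# Minimal usage sketch for convert_catboost (illustrative, not part of the diff).
import catboost
import numpy
from sklearn.datasets import make_regression
from onnxmltools.convert import convert_catboost
from onnxmltools.utils import save_model

X, y = make_regression(n_samples=100, n_features=4, random_state=0)
model = catboost.CatBoostRegressor(n_estimators=10, verbose=0)
model.fit(X.astype(numpy.float32), y)

# Convert the fitted CatBoost model to an ONNX ModelProto and save it to disk.
onnx_model = convert_catboost(model, name='CatBoostRegression', doc_string='example')
save_model(onnx_model, 'catboost_regression.onnx')
```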
@@ -212,6 +212,9 @@ def convert_model(model, name, input_types):
            model, prefix = convert_lightgbm(model, name, input_types), "LightGbm"
        else:
            raise RuntimeError("Unable to convert model of type '{0}'.".format(type(model)))
    elif model.__class__.__name__.startswith("Cat"):
        from onnxmltools.convert import convert_catboost
        model, prefix = convert_catboost(model, name, input_types), "Cat"
    elif isinstance(model, BaseEstimator):
        from onnxmltools.convert import convert_sklearn
        model, prefix = convert_sklearn(model, name, input_types), "Sklearn"

Reviewer (on the startswith("Cat") check): Any better fingerprint to identify the original model?
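One possible stricter fingerprint, sketched under the assumption that catboost can be imported lazily at this point (catboost.CatBoost is the common base class of CatBoostRegressor and CatBoostClassifier); this only illustrates the reviewer's question and is not what the PR does:

```python
# Sketch of a stricter fingerprint than the class-name prefix check (illustrative).
# Intended as a drop-in for the CatBoost branch of convert_model above.
try:
    import catboost  # assumed to be installed whenever a CatBoost model is passed in
except ImportError:
    catboost = None

if catboost is not None and isinstance(model, catboost.CatBoost):
    from onnxmltools.convert import convert_catboost
    model, prefix = convert_catboost(model, name, input_types), "Cat"
```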
@@ -15,3 +15,4 @@ scipy
svm
wheel
xgboost<=1.0.2
catboost
@@ -0,0 +1,53 @@
"""
Tests for CatBoostRegressor and CatBoostClassifier converter.
"""
import unittest
import numpy
import catboost
from sklearn.datasets import make_regression, make_classification
from onnxmltools.convert import convert_catboost
from onnxmltools.utils import dump_data_and_model, dump_single_regression, dump_multiple_classification


class TestCatBoost(unittest.TestCase):
    def test_catboost_regressor(self):
        X, y = make_regression(n_samples=100, n_features=4, random_state=0)
        catboost_model = catboost.CatBoostRegressor(task_type='CPU', loss_function='RMSE',
                                                    n_estimators=10, verbose=0)
        dump_single_regression(catboost_model)

        catboost_model.fit(X.astype(numpy.float32), y)
        catboost_onnx = convert_catboost(catboost_model, name='CatBoostRegression',
                                         doc_string='test regression')
        self.assertTrue(catboost_onnx is not None)
        dump_data_and_model(X.astype(numpy.float32), catboost_model, catboost_onnx, basename="CatBoostReg-Dec4")
    def test_catboost_bin_classifier(self):
        X, y = make_classification(n_samples=100, n_features=4, random_state=0)
        catboost_model = catboost.CatBoostClassifier(task_type='CPU', loss_function='CrossEntropy',
                                                     n_estimators=10, verbose=0)

        catboost_model.fit(X.astype(numpy.float32), y)

        catboost_onnx = convert_catboost(catboost_model, name='CatBoostBinClassification',
                                         doc_string='test binary classification')
        self.assertTrue(catboost_onnx is not None)
        # onnx runtime returns zeros as class labels
        # dump_data_and_model(X.astype(numpy.float32), catboost_model, catboost_onnx, basename="CatBoostBinClass")

Reviewer (on the commented-out check): Should this line be uncommented?

Author: Well, this part has a problem :(

Reviewer: This must be fixed; it is probably an error somewhere in the ONNX graph.

Author: Thank you for the information. I reported your reply to the CatBoost team members and I will update my PR after they fix the bug.

Author: It works now with the new onnxruntime version.
    def test_catboost_multi_classifier(self):
        X, y = make_classification(n_samples=10, n_informative=8, n_classes=3, random_state=0)
        catboost_model = catboost.CatBoostClassifier(task_type='CPU', loss_function='MultiClass', n_estimators=100,
                                                     verbose=0)

        dump_multiple_classification(catboost_model)

        catboost_model.fit(X.astype(numpy.float32), y)
        catboost_onnx = convert_catboost(catboost_model, name='CatBoostMultiClassification',
                                         doc_string='test multiclass classification')
        self.assertTrue(catboost_onnx is not None)
        dump_data_and_model(X.astype(numpy.float32), catboost_model, catboost_onnx, basename="CatBoostMultiClass")


if __name__ == "__main__":
    unittest.main()
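Related to the binary-classifier thread above (onnxruntime returning zeros as class labels): one quick way to inspect such a mismatch is to run the converted model through onnxruntime directly and compare with CatBoost's own predictions. A minimal sketch, assuming onnxruntime is installed, the graph has a single input, and reusing the variables from the binary-classifier test:

```python
# Sketch: compare CatBoost predictions with raw onnxruntime output (illustrative).
import onnxruntime

sess = onnxruntime.InferenceSession(catboost_onnx.SerializeToString())
input_name = sess.get_inputs()[0].name
onnx_outputs = sess.run(None, {input_name: X.astype(numpy.float32)})

# The first output is typically the predicted label tensor.
print('catboost labels:', catboost_model.predict(X.astype(numpy.float32))[:5])
print('onnx labels    :', onnx_outputs[0][:5])
```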
Reviewer (on the convert_catboost signature): The other converters keep arguments like targeted_onnx=onnx.__version__, custom_conversion_functions=None, custom_shape_calculators=None for backward compatibility; since this converter is brand new, these arguments could be dropped.

Author: Thanks, did so.
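Putting the review feedback together, the converter signature after dropping the legacy and unsupported arguments would look roughly like the sketch below (an approximation of where the discussion ends up, not the exact final code of this PR):

```python
# Rough sketch of convert_catboost after the review feedback (approximation only).
def convert_catboost(model, name=None, initial_types=None, doc_string='', target_opset=None):
    try:
        from catboost.utils import convert_to_onnx_object
    except ImportError:
        raise RuntimeError('CatBoost is not installed or needs to be updated. '
                           'Please install/upgrade CatBoost to use this feature.')

    # Per the discussion above, the custom 'ai.catboost' domain was dropped and
    # target_opset handling is delegated to CatBoost itself (not shown here).
    export_parameters = {
        'onnx_doc_string': doc_string,
        'onnx_graph_name': name,
    }
    return convert_to_onnx_object(model, export_parameters=export_parameters)
```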