Skip to content

Commit

Permalink
add unit test to check on XGBRFRegressor
Browse files Browse the repository at this point in the history
Signed-off-by: Xavier Dupre <xadupre@microsoft.com>
  • Loading branch information
xadupre committed Jan 4, 2024
1 parent 180e733 commit 6d1e625
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ build/
# test generated files
.pytest_cache
.cache
debug*
dump*
model*.json
tests/temp
tests/utils/models/coreml_OneHotEncoder_BikeSharing_new.json
tests/utils/models/coreml_OneHotEncoder_BikeSharing2.onnx
Expand Down
65 changes: 65 additions & 0 deletions tests/xgboost/test_xgboost_issue.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# SPDX-License-Identifier: Apache-2.0
import unittest
import numpy as np
from numpy.testing import assert_almost_equal
from sklearn.datasets import load_iris
from xgboost import XGBClassifier, XGBRegressor, XGBRFClassifier, XGBRFRegressor
from onnx.defs import onnx_opset_version
from onnxconverter_common.onnx_ex import DEFAULT_OPSET_NUMBER
from onnxmltools.convert import convert_xgboost
from onnxmltools.convert.common.data_types import FloatTensorType
from onnxruntime import InferenceSession


TARGET_OPSET = min(DEFAULT_OPSET_NUMBER, onnx_opset_version())


class TestXGBoostIssue(unittest.TestCase):
def common_test(self, cls, n_estimators):
dataset = load_iris()
X, y = dataset.data, dataset.target
model = cls(
n_estimators=n_estimators,
learning_rate=1.0,
subsample=0.8,
colsample_bynode=0.8,
reg_lambda=1e-5,
)
model.fit(X, y)
data = np.random.rand(5, 4).astype(np.float32)
expected_labels = model.predict(data)
expected_probabilities = (
model.predict_proba(data) if hasattr(model, "predict_proba") else None
)

onnx_model = convert_xgboost(
model, initial_types=[("input", FloatTensorType(shape=[None, None]))]
)

session = InferenceSession(
onnx_model.SerializeToString(), providers=["CPUExecutionProvider"]
)

if expected_probabilities is None:
(onnx_predictions,) = session.run(None, {"input": data})
assert_almost_equal(expected_labels, onnx_predictions.ravel())
else:
onnx_predictions, onnx_probabilities = session.run(None, {"input": data})
assert_almost_equal(expected_probabilities, onnx_probabilities)
assert_almost_equal(expected_labels, onnx_predictions.ravel())

def test_issue_663_classifier(self):
self.common_test(XGBClassifier, 1)
self.common_test(XGBRFClassifier, 1)
self.common_test(XGBClassifier, 2)
self.common_test(XGBRFClassifier, 2)

def test_issue_663_regressor(self):
self.common_test(XGBRegressor, 1)
self.common_test(XGBRFRegressor, 1)
self.common_test(XGBRegressor, 2)
self.common_test(XGBRFRegressor, 2)


if __name__ == "__main__":
unittest.main(verbosity=2)

0 comments on commit 6d1e625

Please sign in to comment.