From 7858f9f725bfd355618992afda765bc61158b894 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Xavier=20Dupr=C3=A9?=
Date: Sat, 16 Dec 2023 13:10:36 +0100
Subject: [PATCH] Fix discrepancies with XGBRegressor and xgboost > 2 (#670)

* Fix discrepancies with XGBRegressor and xgboost > 2

Signed-off-by: Xavier Dupre

* improve ci version

Signed-off-by: Xavier Dupre

* fix many issues

Signed-off-by: Xavier Dupre

* more fixes

Signed-off-by: Xavier Dupre

* many fixes for xgboost

Signed-off-by: Xavier Dupre

* fix requirements.

Signed-off-by: Xavier Dupre

* fix converters for old versions of xgboost

Signed-off-by: Xavier Dupre

* fix pipeline and base_score

Signed-off-by: Xavier Dupre

* poisson

Signed-off-by: Xavier Dupre

---------

Signed-off-by: Xavier Dupre
---
 .azure-pipelines/linux-conda-CI.yml           |  20 ++-
 .azure-pipelines/win32-conda-CI.yml           |   7 +
 CHANGELOGS.md                                 |   2 +
 onnxmltools/convert/xgboost/_parse.py         |  13 +-
 onnxmltools/convert/xgboost/common.py         |  26 +++-
 .../xgboost/operator_converters/XGBoost.py    |  32 +++--
 .../xgboost/shape_calculators/Classifier.py   |  20 ++-
 onnxmltools/utils/utils_backend.py            |   2 +
 requirements-dev.txt                          |   3 +-
 requirements.txt                              |   5 +-
 tests/baseline/test_convert_baseline.py       |   8 +-
 tests/h2o/test_h2o_converters.py              |   1 -
 tests/xgboost/test_xgboost_13.py              |  64 ---------
 tests/xgboost/test_xgboost_converters.py      |  60 ++++++--
 .../test_xgboost_converters_base_score.py     | 136 ++++++++++++++++++
 15 files changed, 286 insertions(+), 113 deletions(-)
 delete mode 100644 tests/xgboost/test_xgboost_13.py
 create mode 100644 tests/xgboost/test_xgboost_converters_base_score.py

diff --git a/.azure-pipelines/linux-conda-CI.yml b/.azure-pipelines/linux-conda-CI.yml
index 20f7028c..eff402ed 100644
--- a/.azure-pipelines/linux-conda-CI.yml
+++ b/.azure-pipelines/linux-conda-CI.yml
@@ -15,13 +15,23 @@ jobs:
 
   strategy:
     matrix:
+      Python311-1150-RT1163-xgb2-lgbm40:
+        python.version: '3.11'
+        ONNX_PATH: 'onnx==1.15.0'
+        ONNXRT_PATH: 'onnxruntime==1.16.3'
+        COREML_PATH: NONE
+        lightgbm.version: '>=4.0'
+        xgboost.version: '>=2'
+        numpy.version: ''
+        scipy.version: ''
+
       Python311-1150-RT1160-xgb175-lgbm40:
         python.version: '3.11'
         ONNX_PATH: 'onnx==1.15.0'
         ONNXRT_PATH: 'onnxruntime==1.16.2'
         COREML_PATH: NONE
         lightgbm.version: '>=4.0'
-        xgboost.version: '>=1.7.5'
+        xgboost.version: '==1.7.5'
         numpy.version: ''
         scipy.version: ''
 
@@ -31,7 +41,7 @@ jobs:
         ONNXRT_PATH: 'onnxruntime==1.16.2'
         COREML_PATH: NONE
         lightgbm.version: '>=4.0'
-        xgboost.version: '>=1.7.5'
+        xgboost.version: '==1.7.5'
         numpy.version: ''
         scipy.version: ''
 
@@ -41,7 +51,7 @@ jobs:
         ONNXRT_PATH: 'onnxruntime==1.15.1'
         COREML_PATH: NONE
         lightgbm.version: '<4.0'
-        xgboost.version: '>=1.7.5'
+        xgboost.version: '==1.7.5'
         numpy.version: ''
         scipy.version: ''
 
@@ -51,7 +61,7 @@ jobs:
         ONNXRT_PATH: 'onnxruntime==1.14.0'
         COREML_PATH: NONE
         lightgbm.version: '<4.0'
-        xgboost.version: '>=1.7.5'
+        xgboost.version: '==1.7.5'
         numpy.version: ''
         scipy.version: ''
 
@@ -61,7 +71,7 @@ jobs:
         ONNXRT_PATH: 'onnxruntime==1.15.1'
         COREML_PATH: NONE
         lightgbm.version: '>=4.0'
-        xgboost.version: '>=1.7.5'
+        xgboost.version: '==1.7.5'
         numpy.version: ''
         scipy.version: '==1.8.0'
 
diff --git a/.azure-pipelines/win32-conda-CI.yml b/.azure-pipelines/win32-conda-CI.yml
index 09cc2e50..7ccc5f03 100644
--- a/.azure-pipelines/win32-conda-CI.yml
+++ b/.azure-pipelines/win32-conda-CI.yml
@@ -21,6 +21,7 @@ jobs:
         ONNXRT_PATH: 'onnxruntime==1.16.2'
         COREML_PATH: NONE
         numpy.version: ''
+        xgboost.version: '2.0.2'
 
       Python311-1141-RT1162:
         python.version: '3.11'
@@ -28,6 +29,7 @@ jobs:
         ONNXRT_PATH: 'onnxruntime==1.16.2'
         COREML_PATH: NONE
         numpy.version: ''
+        xgboost.version: '1.7.5'
 
       Python310-1141-RT1151:
         python.version: '3.10'
@@ -35,6 +37,7 @@ jobs:
         ONNXRT_PATH: 'onnxruntime==1.15.1'
         COREML_PATH: NONE
         numpy.version: ''
+        xgboost.version: '1.7.5'
 
       Python310-1141-RT1140:
         python.version: '3.10'
@@ -42,6 +44,7 @@ jobs:
         ONNXRT_PATH: onnxruntime==1.14.0
         COREML_PATH: NONE
         numpy.version: ''
+        xgboost.version: '1.7.5'
 
       Python39-1141-RT1140:
         python.version: '3.9'
@@ -49,6 +51,7 @@ jobs:
         ONNXRT_PATH: onnxruntime==1.14.0
         COREML_PATH: NONE
         numpy.version: ''
+        xgboost.version: '1.7.5'
 
     maxParallel: 3
 
@@ -74,6 +79,8 @@ jobs:
   - script: |
       call activate py$(python.version)
       python -m pip install --upgrade scikit-learn
+      python -m pip install --upgrade lightgbm
+      python -m pip install "xgboost==$(xgboost.version)"
     displayName: 'Install scikit-learn'
 
   - script: |
diff --git a/CHANGELOGS.md b/CHANGELOGS.md
index d0151fdf..514e7508 100644
--- a/CHANGELOGS.md
+++ b/CHANGELOGS.md
@@ -2,6 +2,8 @@
 
 ## 1.12.0
 
+* Fix discrepancies with XGBRegressor and xgboost > 2
+  [#670](https://github.com/onnx/onnxmltools/pull/670)
 * Support count:poisson for XGBRegressor
   [#666](https://github.com/onnx/onnxmltools/pull/666)
 * Supports XGBRFClassifier and XGBRFRegressor
diff --git a/onnxmltools/convert/xgboost/_parse.py b/onnxmltools/convert/xgboost/_parse.py
index 9e67209a..af4df543 100644
--- a/onnxmltools/convert/xgboost/_parse.py
+++ b/onnxmltools/convert/xgboost/_parse.py
@@ -70,12 +70,15 @@ def _get_attributes(booster):
         ntrees = trees // num_class if num_class > 0 else trees
     else:
         trees = len(res)
-        ntrees = booster.best_ntree_limit
-        num_class = trees // ntrees
+        ntrees = getattr(booster, "best_ntree_limit", trees)
+        config = json.loads(booster.save_config())["learner"]["learner_model_param"]
+        num_class = int(config["num_class"]) if "num_class" in config else 0
+        if num_class == 0 and ntrees > 0:
+            num_class = trees // ntrees
         if num_class == 0:
             raise RuntimeError(
-                "Unable to retrieve the number of classes, trees=%d, ntrees=%d."
-                % (trees, ntrees)
+                f"Unable to retrieve the number of classes, num_class={num_class}, "
+                f"trees={trees}, ntrees={ntrees}, config={config}."
             )
 
     kwargs = atts.copy()
@@ -137,7 +140,7 @@ def __init__(self, booster):
         self.operator_name = "XGBRegressor"
 
     def get_xgb_params(self):
-        return self.kwargs
+        return {k: v for k, v in self.kwargs.items() if v is not None}
 
     def get_booster(self):
         return self.booster_
""" +import json def get_xgb_params(xgb_node): @@ -15,8 +16,29 @@ def get_xgb_params(xgb_node): else: # XGBoost < 0.7 params = xgb_node.__dict__ - + if hasattr("xgb_node", "save_config"): + config = json.loads(xgb_node.save_config()) + else: + config = json.loads(xgb_node.get_booster().save_config()) + params = {k: v for k, v in params.items() if v is not None} + num_class = int(config["learner"]["learner_model_param"]["num_class"]) + if num_class > 0: + params["num_class"] = num_class if "n_estimators" not in params and hasattr(xgb_node, "n_estimators"): # xgboost >= 1.0.2 - params["n_estimators"] = xgb_node.n_estimators + if xgb_node.n_estimators is not None: + params["n_estimators"] = xgb_node.n_estimators + if "base_score" in config["learner"]["learner_model_param"]: + bs = float(config["learner"]["learner_model_param"]["base_score"]) + # xgboost >= 2.0 + params["base_score"] = bs return params + + +def get_n_estimators_classifier(xgb_node, params, js_trees): + if "n_estimators" not in params: + num_class = params.get("num_class", 0) + if num_class == 0: + return len(js_trees) + return len(js_trees) // num_class + return params["n_estimators"] diff --git a/onnxmltools/convert/xgboost/operator_converters/XGBoost.py b/onnxmltools/convert/xgboost/operator_converters/XGBoost.py index 42a3964a..a9f31211 100644 --- a/onnxmltools/convert/xgboost/operator_converters/XGBoost.py +++ b/onnxmltools/convert/xgboost/operator_converters/XGBoost.py @@ -10,7 +10,7 @@ except ImportError: XGBRFClassifier = None from ...common._registration import register_converter -from ..common import get_xgb_params +from ..common import get_xgb_params, get_n_estimators_classifier class XGBConverter: @@ -161,8 +161,7 @@ def _fill_node_attributes( false_child_id=remap[jsnode["no"]], # ['children'][1]['nodeid'], weights=None, weight_id_bias=None, - missing=jsnode.get("missing", -1) - == jsnode["yes"], # ['children'][0]['nodeid'], + missing=jsnode.get("missing", -1) == jsnode["yes"], hitrate=jsnode.get("cover", 0), ) @@ -265,8 +264,8 @@ def convert(scope, operator, container): ) if objective == "count:poisson": - cst = scope.get_unique_variable_name("half") - container.add_initializer(cst, TensorProto.FLOAT, [1], [0.5]) + cst = scope.get_unique_variable_name("poisson") + container.add_initializer(cst, TensorProto.FLOAT, [1], [base_score]) new_name = scope.get_unique_variable_name("exp") container.add_node("Exp", names, [new_name]) container.add_node("Mul", [new_name, cst], operator.output_full_names) @@ -293,11 +292,18 @@ def convert(scope, operator, container): objective, base_score, js_trees = XGBConverter.common_members(xgb_node, inputs) params = XGBConverter.get_xgb_params(xgb_node) + n_estimators = get_n_estimators_classifier(xgb_node, params, js_trees) + num_class = params.get("num_class", None) + attr_pairs = XGBClassifierConverter._get_default_tree_attribute_pairs() XGBConverter.fill_tree_attributes( js_trees, attr_pairs, [1 for _ in js_trees], True ) - ncl = (max(attr_pairs["class_treeids"]) + 1) // params["n_estimators"] + if num_class is not None: + ncl = num_class + n_estimators = len(js_trees) // ncl + else: + ncl = (max(attr_pairs["class_treeids"]) + 1) // n_estimators bst = xgb_node.get_booster() best_ntree_limit = getattr(bst, "best_ntree_limit", len(js_trees)) * ncl @@ -310,6 +316,7 @@ def convert(scope, operator, container): if len(attr_pairs["class_treeids"]) == 0: raise RuntimeError("XGBoost model is empty.") + if ncl <= 1: ncl = 2 if objective != "binary:hinge": @@ -317,8 +324,9 @@ def convert(scope, 
diff --git a/onnxmltools/convert/xgboost/shape_calculators/Classifier.py b/onnxmltools/convert/xgboost/shape_calculators/Classifier.py
index a543e103..e44d5a52 100644
--- a/onnxmltools/convert/xgboost/shape_calculators/Classifier.py
+++ b/onnxmltools/convert/xgboost/shape_calculators/Classifier.py
@@ -8,7 +8,7 @@
     Int64TensorType,
     StringTensorType,
 )
-from ..common import get_xgb_params
+from ..common import get_xgb_params, get_n_estimators_classifier
 
 
 def calculate_xgboost_classifier_output_shapes(operator):
@@ -22,18 +22,26 @@ def calculate_xgboost_classifier_output_shapes(operator):
     params = get_xgb_params(xgb_node)
     booster = xgb_node.get_booster()
     booster.attributes()
-    ntrees = len(booster.get_dump(with_stats=True, dump_format="json"))
+    js_trees = booster.get_dump(with_stats=True, dump_format="json")
+    ntrees = len(js_trees)
     objective = params["objective"]
+    n_estimators = get_n_estimators_classifier(xgb_node, params, js_trees)
+    num_class = params.get("num_class", None)
 
-    if objective == "binary:logistic":
+    if num_class is not None:
+        ncl = num_class
+        n_estimators = ntrees // ncl
+    elif objective == "binary:logistic":
         ncl = 2
     else:
-        ncl = ntrees // params["n_estimators"]
+        ncl = ntrees // n_estimators
         if objective == "reg:logistic" and ncl == 1:
             ncl = 2
     classes = xgb_node.classes_
-    if np.issubdtype(classes.dtype, np.floating) or np.issubdtype(
-        classes.dtype, np.integer
+    if (
+        np.issubdtype(classes.dtype, np.floating)
+        or np.issubdtype(classes.dtype, np.integer)
+        or np.issubdtype(classes.dtype, np.bool_)
     ):
         operator.outputs[0].type = Int64TensorType(shape=[N])
     else:
diff --git a/onnxmltools/utils/utils_backend.py b/onnxmltools/utils/utils_backend.py
index 84dbe146..f7ab5b42 100644
--- a/onnxmltools/utils/utils_backend.py
+++ b/onnxmltools/utils/utils_backend.py
@@ -188,6 +188,8 @@ def compare_outputs(expected, output, **kwargs):
     Disc = kwargs.pop("Disc", False)
     Mism = kwargs.pop("Mism", False)
     Opp = kwargs.pop("Opp", False)
+    if hasattr(expected, "dtype") and expected.dtype == numpy.bool_:
+        expected = expected.astype(numpy.int64)
     if Opp and not NoProb:
         raise ValueError("Opp is only available if NoProb is True")
ValueError("Opp is only available if NoProb is True") diff --git a/requirements-dev.txt b/requirements-dev.txt index 1345c554..cfd46ddc 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -15,6 +15,7 @@ pytest-spark ruff scikit-learn>=1.2.0 scipy +skl2onnx wheel -xgboost==1.7.5 +xgboost onnxruntime diff --git a/requirements.txt b/requirements.txt index 361b3238..308201c6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,2 @@ -numpy -onnx -skl2onnx +numpy +onnx diff --git a/tests/baseline/test_convert_baseline.py b/tests/baseline/test_convert_baseline.py index 3b56bbba..44e6acd4 100644 --- a/tests/baseline/test_convert_baseline.py +++ b/tests/baseline/test_convert_baseline.py @@ -45,11 +45,11 @@ def get_diff(self, input_file, ref_file): def normalize_diff(self, diff): invalid_comparisons = [] invalid_comparisons.append( - re.compile('producer_version: "\d+\.\d+\.\d+\.\d+.*') + re.compile('producer_version: "\\d+\\.\\d+\\.\\d+\\.\\d+.*') ) - invalid_comparisons.append(re.compile('\s+name: ".*')) - invalid_comparisons.append(re.compile("ir_version: \d+")) - invalid_comparisons.append(re.compile("\s+")) + invalid_comparisons.append(re.compile('\\s+name: ".*')) + invalid_comparisons.append(re.compile("ir_version: \\d+")) + invalid_comparisons.append(re.compile("\\s+")) valid_diff = set() for line in diff: if any(comparison.match(line) for comparison in invalid_comparisons): diff --git a/tests/h2o/test_h2o_converters.py b/tests/h2o/test_h2o_converters.py index 100ec5cc..56ecb832 100644 --- a/tests/h2o/test_h2o_converters.py +++ b/tests/h2o/test_h2o_converters.py @@ -223,7 +223,6 @@ def test_h2o_classifier_multi_cat(self): train, test = _prepare_one_hot("airlines.csv", y) gbm = H2OGradientBoostingEstimator(ntrees=8, max_depth=5) mojo_path = _make_mojo(gbm, train, y=train.columns.index(y)) - print("****", mojo_path) onnx_model = _convert_mojo(mojo_path) self.assertIsNot(onnx_model, None) dump_data_and_model( diff --git a/tests/xgboost/test_xgboost_13.py b/tests/xgboost/test_xgboost_13.py deleted file mode 100644 index 18d31bab..00000000 --- a/tests/xgboost/test_xgboost_13.py +++ /dev/null @@ -1,64 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -""" -Tests scilit-learn's tree-based methods' converters. 
-""" -import os -import unittest -import numpy as np -from numpy.testing import assert_almost_equal -import pandas -from sklearn.model_selection import train_test_split -from xgboost import XGBClassifier -from onnx.defs import onnx_opset_version -from onnxconverter_common.onnx_ex import DEFAULT_OPSET_NUMBER -from onnxmltools.convert import convert_xgboost -from onnxmltools.convert.common.data_types import FloatTensorType -from onnxruntime import InferenceSession - - -TARGET_OPSET = min(DEFAULT_OPSET_NUMBER, onnx_opset_version()) - - -class TestXGBoost13(unittest.TestCase): - def test_xgb_regressor(self): - this = os.path.dirname(__file__) - df = pandas.read_csv(os.path.join(this, "data_fail_empty.csv")) - X, y = df.drop("y", axis=1), df["y"] - X_train, X_test, y_train, y_test = train_test_split(X, y) - - clr = XGBClassifier( - max_delta_step=0, - tree_method="hist", - n_estimators=100, - booster="gbtree", - objective="binary:logistic", - eval_metric="logloss", - learning_rate=0.1, - gamma=10, - max_depth=7, - min_child_weight=50, - subsample=0.75, - colsample_bytree=0.75, - random_state=42, - verbosity=0, - ) - - clr.fit(X_train, y_train, eval_set=[(X_test, y_test)], early_stopping_rounds=40) - - initial_type = [("float_input", FloatTensorType([None, 797]))] - onx = convert_xgboost( - clr, initial_types=initial_type, target_opset=TARGET_OPSET - ) - expected = clr.predict(X_test), clr.predict_proba(X_test) - sess = InferenceSession( - onx.SerializeToString(), providers=["CPUExecutionProvider"] - ) - X_test = X_test.values.astype(np.float32) - got = sess.run(None, {"float_input": X_test}) - assert_almost_equal(expected[1], got[1]) - assert_almost_equal(expected[0], got[0]) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/xgboost/test_xgboost_converters.py b/tests/xgboost/test_xgboost_converters.py index 4b1626ea..7a37dd3a 100644 --- a/tests/xgboost/test_xgboost_converters.py +++ b/tests/xgboost/test_xgboost_converters.py @@ -92,14 +92,14 @@ def test_xgb_regressor(self): def test_xgb_regressor_poisson(self): iris = load_diabetes() x = iris.data - y = iris.target + y = iris.target / 100 x_train, x_test, y_train, _ = train_test_split( - x, y, test_size=0.5, random_state=42 + x, y, test_size=0.5, random_state=17 ) for nest in [5, 50]: xgb = XGBRegressor( objective="count:poisson", - random_state=0, + random_state=5, max_depth=3, n_estimators=nest, ) @@ -116,7 +116,7 @@ def test_xgb_regressor_poisson(self): basename=f"SklearnXGBRegressorPoisson{nest}-Dec3", ) - def test_xgb_classifier(self): + def test_xgb0_classifier(self): xgb, x_test = _fit_classification_model(XGBClassifier(), 2) conv_model = convert_xgboost( xgb, @@ -198,7 +198,7 @@ def test_xgb_classifier_multi_discrete_int_labels(self): basename="SklearnXGBClassifierMultiDiscreteIntLabels", ) - def test_xgboost_booster_classifier_bin(self): + def test_xgb1_booster_classifier_bin(self): x, y = make_classification( n_classes=2, n_features=5, n_samples=100, random_state=42, n_informative=3 ) @@ -221,7 +221,7 @@ def test_xgboost_booster_classifier_bin(self): x_test.astype(np.float32), model, model_onnx, basename="XGBBoosterMCl" ) - def test_xgboost_booster_classifier_multiclass_softprob(self): + def test_xgb0_booster_classifier_multiclass_softprob(self): x, y = make_classification( n_classes=3, n_features=5, n_samples=100, random_state=42, n_informative=3 ) @@ -283,11 +283,12 @@ def test_xgboost_booster_classifier_multiclass_softmax(self): basename="XGBBoosterMClSoftMax", ) - def test_xgboost_booster_classifier_reg(self): + 
@@ -283,11 +283,11 @@ def test_xgboost_booster_classifier_multiclass_softmax(self):
             basename="XGBBoosterMClSoftMax",
         )
 
-    def test_xgboost_booster_classifier_reg(self):
+    def test_xgboost_booster_reg(self):
         x, y = make_classification(
             n_classes=2, n_features=5, n_samples=100, random_state=42, n_informative=3
         )
         y = y.astype(np.float32) + 0.567
         x_train, x_test, y_train, _ = train_test_split(
             x, y, test_size=0.5, random_state=42
         )
@@ -431,7 +431,7 @@ def test_xgboost_example_mnist(self):
             X_test.astype(np.float32), clf, onnx_model, basename="XGBoostExample"
         )
 
-    def test_xgb_empty_tree(self):
+    def test_xgb0_empty_tree_classifier(self):
         xgb = XGBClassifier(n_estimators=2, max_depth=2)
 
         # simple dataset
@@ -451,7 +451,7 @@ def test_xgb_empty_tree(self):
         assert_almost_equal(xgb.predict_proba(X), res[1])
         assert_almost_equal(xgb.predict(X), res[0])
 
-    def test_xgb_best_tree_limit(self):
+    def test_xgb_best_tree_limit_classifier(self):
         # Train
         iris = load_iris()
         X, y = iris.data, iris.target
@@ -499,7 +499,7 @@ def test_xgb_best_tree_limit(self):
         )
         assert_almost_equal(bst_loaded.predict(dtest), res[0])
 
-    def test_onnxrt_python_xgbclassifier(self):
+    def test_xgb_classifier(self):
         x = np.random.randn(100, 10).astype(np.float32)
         y = ((x.sum(axis=1) + np.random.randn(x.shape[0]) / 50 + 0.5) >= 0).astype(
             np.int64
@@ -675,7 +675,44 @@ def test_xgb_classifier_hinge(self):
             x_test, xgb, conv_model, basename="SklearnXGBClassifierHinge"
         )
 
+    def test_xgb_classifier_13(self):
+        this = os.path.dirname(__file__)
+        df = pandas.read_csv(os.path.join(this, "data_fail_empty.csv"))
+        X, y = df.drop("y", axis=1), df["y"]
+        X_train, X_test, y_train, y_test = train_test_split(X, y)
+
+        clr = XGBClassifier(
+            max_delta_step=0,
+            tree_method="hist",
+            n_estimators=100,
+            booster="gbtree",
+            objective="binary:logistic",
+            eval_metric="logloss",
+            learning_rate=0.1,
+            gamma=10,
+            max_depth=7,
+            min_child_weight=50,
+            subsample=0.75,
+            colsample_bytree=0.75,
+            random_state=42,
+            verbosity=0,
+        )
+
+        clr.fit(X_train, y_train, eval_set=[(X_test, y_test)], early_stopping_rounds=40)
+
+        initial_type = [("float_input", FloatTensorType([None, 797]))]
+        onx = convert_xgboost(
+            clr, initial_types=initial_type, target_opset=TARGET_OPSET
+        )
+        expected = clr.predict(X_test), clr.predict_proba(X_test)
+        sess = InferenceSession(
+            onx.SerializeToString(), providers=["CPUExecutionProvider"]
+        )
+        X_test = X_test.values.astype(np.float32)
+        got = sess.run(None, {"float_input": X_test})
+        assert_almost_equal(expected[1], got[1])
+        assert_almost_equal(expected[0], got[0])
+
 
 if __name__ == "__main__":
-    TestXGBoostModels().test_xgb_regressor_poisson()
     unittest.main(verbosity=2)
diff --git a/tests/xgboost/test_xgboost_converters_base_score.py b/tests/xgboost/test_xgboost_converters_base_score.py
new file mode 100644
index 00000000..27d40172
--- /dev/null
+++ b/tests/xgboost/test_xgboost_converters_base_score.py
@@ -0,0 +1,136 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+import numpy as np
+import scipy
+from numpy.testing import assert_almost_equal
+from sklearn.datasets import make_regression
+from xgboost import XGBClassifier, XGBRegressor
+from onnx.defs import onnx_opset_version
+from onnxconverter_common.onnx_ex import DEFAULT_OPSET_NUMBER
+from onnxmltools.convert import convert_xgboost
+from onnxmltools.convert.common.data_types import FloatTensorType
+from onnxruntime import InferenceSession
+
+
+TARGET_OPSET = min(DEFAULT_OPSET_NUMBER, onnx_opset_version())
+
+
+class TestXGBoostModelsBaseScore(unittest.TestCase):
+    def test_xgbregressor_sparse_base_score(self):
+        X, y = make_regression(n_samples=200, n_features=10, random_state=0)
+        mask = np.random.randint(0, 50, size=(X.shape)) != 0
+        X[mask] = 0
+        y = (y + mask.sum(axis=1, keepdims=0)).astype(np.float32)
+        X_sp = scipy.sparse.coo_matrix(X)
+        X = X.astype(np.float32)
+
+        rf = XGBRegressor(n_estimators=3, max_depth=4, random_state=0, base_score=0.5)
+        rf.fit(X_sp, y)
+        expected = rf.predict(X).astype(np.float32).reshape((-1, 1))
+        expected_sparse = rf.predict(X_sp).astype(np.float32).reshape((-1, 1))
+        diff = np.abs(expected - expected_sparse)
+        self.assertNotEqual(diff.min(), diff.max())
+
+        onx = convert_xgboost(
+            rf,
+            initial_types=[("X", FloatTensorType(shape=[None, None]))],
+            target_opset=TARGET_OPSET,
+        )
+        feeds = {"X": X}
+
+        sess = InferenceSession(
+            onx.SerializeToString(), providers=["CPUExecutionProvider"]
+        )
+        got = sess.run(None, feeds)[0]
+        assert_almost_equal(expected, got, decimal=4)
+
+    def test_xgbregressor_sparse_no_base_score(self):
+        X, y = make_regression(n_samples=200, n_features=10, random_state=0)
+        mask = np.random.randint(0, 50, size=(X.shape)) != 0
+        X[mask] = 0
+        y = (y + mask.sum(axis=1, keepdims=0)).astype(np.float32)
+        X_sp = scipy.sparse.coo_matrix(X)
+        X = X.astype(np.float32)
+
+        rf = XGBRegressor(n_estimators=3, max_depth=4, random_state=0)
+        rf.fit(X_sp, y)
+        expected = rf.predict(X).astype(np.float32).reshape((-1, 1))
+        expected_sparse = rf.predict(X_sp).astype(np.float32).reshape((-1, 1))
+        diff = np.abs(expected - expected_sparse)
+        self.assertNotEqual(diff.min(), diff.max())
+
+        onx = convert_xgboost(
+            rf,
+            initial_types=[("X", FloatTensorType(shape=[None, None]))],
+            target_opset=TARGET_OPSET,
+        )
+        feeds = {"X": X}
+
+        sess = InferenceSession(
+            onx.SerializeToString(), providers=["CPUExecutionProvider"]
+        )
+        got = sess.run(None, feeds)[0]
+        assert_almost_equal(expected, got, decimal=4)
+
+    def test_xgbclassifier_sparse_base_score(self):
+        X, y = make_regression(n_samples=200, n_features=10, random_state=0)
+        mask = np.random.randint(0, 50, size=(X.shape)) != 0
+        X[mask] = 0
+        y = (y + mask.sum(axis=1, keepdims=0)).astype(np.float32)
+        y = y >= y.mean()
+        X_sp = scipy.sparse.coo_matrix(X)
+        X = X.astype(np.float32)
+
+        rf = XGBClassifier(n_estimators=3, max_depth=4, random_state=0, base_score=0.5)
+        rf.fit(X_sp, y)
+        expected = rf.predict_proba(X).astype(np.float32).reshape((-1, 1))
+        expected_sparse = rf.predict_proba(X_sp).astype(np.float32).reshape((-1, 1))
+        diff = np.abs(expected - expected_sparse)
+        self.assertNotEqual(diff.min(), diff.max())
+
+        onx = convert_xgboost(
+            rf,
+            initial_types=[("X", FloatTensorType(shape=[None, None]))],
+            target_opset=TARGET_OPSET,
+        )
+        feeds = {"X": X}
+
+        sess = InferenceSession(
+            onx.SerializeToString(), providers=["CPUExecutionProvider"]
+        )
+        got = sess.run(None, feeds)[1]
+        assert_almost_equal(expected.reshape((-1, 2)), got, decimal=4)
+
+    def test_xgbclassifier_sparse_no_base_score(self):
+        X, y = make_regression(n_samples=200, n_features=10, random_state=0)
+        mask = np.random.randint(0, 50, size=(X.shape)) != 0
+        X[mask] = 0
+        y = (y + mask.sum(axis=1, keepdims=0)).astype(np.float32)
+        y = y >= y.mean()
+        X_sp = scipy.sparse.coo_matrix(X)
+        X = X.astype(np.float32)
+
+        rf = XGBClassifier(n_estimators=3, max_depth=4, random_state=0)
+        rf.fit(X_sp, y)
+        expected = rf.predict_proba(X).astype(np.float32).reshape((-1, 1))
+        expected_sparse = rf.predict_proba(X_sp).astype(np.float32).reshape((-1, 1))
+        diff = np.abs(expected - expected_sparse)
+        self.assertNotEqual(diff.min(), diff.max())
+
+        onx = convert_xgboost(
+            rf,
+            initial_types=[("X", FloatTensorType(shape=[None, None]))],
+            target_opset=TARGET_OPSET,
+        )
+        feeds = {"X": X}
+
+        sess = InferenceSession(
+            onx.SerializeToString(), providers=["CPUExecutionProvider"]
+        )
+        got = sess.run(None, feeds)[1]
+        assert_almost_equal(expected.reshape((-1, 2)), got, decimal=4)
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)
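
--
Appendix: illustrative sketches (not part of the patch).

The common.py changes above read base_score and num_class out of the booster's
saved configuration, because xgboost >= 2 no longer defaults base_score to 0.5
but estimates it from the training data. A minimal sketch of that lookup,
assuming only a fitted XGBRegressor (the toy data is illustrative):

    import json

    import numpy as np
    from xgboost import XGBRegressor

    X = np.random.rand(100, 4).astype(np.float32)
    y = X.sum(axis=1)
    model = XGBRegressor(n_estimators=3, max_depth=2).fit(X, y)

    # save_config() returns the learner configuration as a JSON string;
    # learner_model_param stores base_score and num_class as strings.
    config = json.loads(model.get_booster().save_config())
    lmp = config["learner"]["learner_model_param"]
    print(float(lmp["base_score"]), int(lmp["num_class"]))  # data-dependent, 0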
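
For binary:logistic, XGBoost.py above stores -log(1/base_score - 1), that is
logit(base_score), in base_values: TreeEnsembleClassifier adds base_values to
the raw tree sum before the LOGISTIC post-transform, so the logit is what makes
the transformed output match xgboost. A quick numpy check of that identity
(base_score == 0.5 gives cst == 0, hence the comment in the hunk):

    import numpy as np

    for base_score in (0.2, 0.5, 0.9):
        cst = -np.log(1.0 / np.float32(base_score) - 1.0)  # logit(base_score)
        prob = 1.0 / (1.0 + np.exp(-cst))  # LOGISTIC post-transform
        assert abs(prob - base_score) < 1e-6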
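
The count:poisson branch now multiplies Exp(tree sum) by base_score instead of
the former hard-coded 0.5. This mirrors xgboost's log link, where base_score
enters the margin as log(base_score), so base_score * exp(tree sum) equals
exp(margin). A sanity check against xgboost itself (a sketch; the random
Poisson data is illustrative):

    import numpy as np
    from xgboost import XGBRegressor

    X = np.random.rand(200, 3).astype(np.float32)
    y = np.random.poisson(lam=2.0, size=200).astype(np.float32)
    model = XGBRegressor(objective="count:poisson", n_estimators=4, max_depth=3)
    model.fit(X, y)

    # The raw margin already contains log(base_score), so the prediction is
    # exp(margin) == base_score * exp(tree sum), which Exp -> Mul reproduces.
    margin = model.predict(X, output_margin=True)
    np.testing.assert_allclose(model.predict(X), np.exp(margin), rtol=1e-5)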
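
get_n_estimators_classifier falls back to counting dumped trees when
n_estimators is absent from the parameters (e.g. a raw Booster wrapped by
_parse.py): a multi-class model dumps exactly num_class trees per boosting
round, so the round count is len(js_trees) // num_class. Illustrated with a
plain Booster (a sketch; the dataset parameters are arbitrary):

    import json

    import xgboost
    from sklearn.datasets import make_classification

    X, y = make_classification(n_classes=3, n_informative=5, random_state=0)
    booster = xgboost.train(
        {"objective": "multi:softprob", "num_class": 3},
        xgboost.DMatrix(X, label=y),
        num_boost_round=4,
    )

    trees = booster.get_dump(dump_format="json")
    config = json.loads(booster.save_config())
    num_class = int(config["learner"]["learner_model_param"]["num_class"])
    assert len(trees) == 4 * num_class  # num_class trees per boosting round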