Skip to content

Commit

Permalink
more fixes
Browse files Browse the repository at this point in the history
Signed-off-by: Xavier Dupre <xadupre@microsoft.com>
  • Loading branch information
xadupre committed Dec 11, 2023
1 parent b03a98e commit 47cbee2
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 10 deletions.
3 changes: 2 additions & 1 deletion onnxmltools/convert/xgboost/_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ def _get_attributes(booster):
config = json.loads(booster.save_config())["learner"]["learner_model_param"]
if "num_class" in config:
num_class = int(config["num_class"])
ntrees = len(res) // num_class
ntrees = len(res)
num_class = 1
else:
trees = len(res)
if hasattr(booster, "best_ntree_limit"):
Expand Down
11 changes: 6 additions & 5 deletions onnxmltools/convert/xgboost/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,19 @@ def get_xgb_params(xgb_node):
else:
# XGBoost < 0.7
params = xgb_node.__dict__
if hasattr("xgb_node", "save_config"):
config = json.loads(xgb_node.save_config())
else:
config = json.loads(xgb_node.get_booster().save_config())
num_class = int(config["learner"]["learner_model_param"]["num_class"])
params = {k: v for k, v in params.items() if v is not None}
params["num_class"] = num_class
if "n_estimators" not in params and hasattr(xgb_node, "n_estimators"):
# xgboost >= 1.0.2
if xgb_node.n_estimators is not None:
params["n_estimators"] = xgb_node.n_estimators
if params.get("base_score", None) is None:
# xgboost >= 2.0
if hasattr("xgb_node", "save_config"):
config = json.loads(xgb_node.save_config())
else:
config = json.loads(xgb_node.get_booster().save_config())

params["base_score"] = float(
config["learner"]["learner_model_param"]["base_score"]
)
Expand Down
14 changes: 11 additions & 3 deletions onnxmltools/convert/xgboost/operator_converters/XGBoost.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,12 @@ def convert(scope, operator, container):
XGBConverter.fill_tree_attributes(
js_trees, attr_pairs, [1 for _ in js_trees], True
)
ncl = (max(attr_pairs["class_treeids"]) + 1) // n_estimators
if "num_class" in params:
ncl = params["num_class"]
n_estimators = len(js_trees) // ncl
else:
ncl = (max(attr_pairs["class_treeids"]) + 1) // n_estimators
print("**", params)

bst = xgb_node.get_booster()
best_ntree_limit = getattr(bst, "best_ntree_limit", len(js_trees)) * ncl
Expand All @@ -312,6 +317,7 @@ def convert(scope, operator, container):

if len(attr_pairs["class_treeids"]) == 0:
raise RuntimeError("XGBoost model is empty.")

if ncl <= 1:
ncl = 2
if objective != "binary:hinge":
Expand All @@ -332,8 +338,10 @@ def convert(scope, operator, container):
attr_pairs["class_ids"] = [v % ncl for v in attr_pairs["class_treeids"]]

classes = xgb_node.classes_
if np.issubdtype(classes.dtype, np.floating) or np.issubdtype(
classes.dtype, np.integer
if (
np.issubdtype(classes.dtype, np.floating)
or np.issubdtype(classes.dtype, np.integer)
or np.issubdtype(classes.dtype, np.bool_)
):
attr_pairs["classlabels_int64s"] = classes.astype("int")
else:
Expand Down
8 changes: 7 additions & 1 deletion onnxmltools/convert/xgboost/shape_calculators/Classifier.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

import json
import numpy as np
from ...common._registration import register_shape_calculator
from ...common.utils import check_input_and_output_numbers, check_input_and_output_types
Expand All @@ -26,8 +27,13 @@ def calculate_xgboost_classifier_output_shapes(operator):
ntrees = len(js_trees)
objective = params["objective"]
n_estimators = get_n_estimators_classifier(xgb_node, params, js_trees)
config = json.loads(xgb_node.get_booster().save_config())
num_class = int(config["learner"]["learner_model_param"]["num_class"])

if objective == "binary:logistic":
if num_class is not None:
ncl = num_class
n_estimators = ntrees // ncl
elif objective == "binary:logistic":
ncl = 2
else:
ncl = ntrees // n_estimators
Expand Down
2 changes: 2 additions & 0 deletions onnxmltools/utils/utils_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,8 @@ def compare_outputs(expected, output, **kwargs):
Disc = kwargs.pop("Disc", False)
Mism = kwargs.pop("Mism", False)
Opp = kwargs.pop("Opp", False)
if hasattr(expected, "dtype") and expected.dtype == numpy.bool_:
expected = expected.astype(numpy.int64)
if Opp and not NoProb:
raise ValueError("Opp is only available if NoProb is True")

Expand Down
1 change: 1 addition & 0 deletions tests/xgboost/test_xgboost_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,4 +677,5 @@ def test_xgb_classifier_hinge(self):


if __name__ == "__main__":
TestXGBoostModels().test_xgb_best_tree_limit()
unittest.main(verbosity=2)

0 comments on commit 47cbee2

Please sign in to comment.