diff --git a/onnxmltools/convert/xgboost/_parse.py b/onnxmltools/convert/xgboost/_parse.py index 9e67209a..bb200307 100644 --- a/onnxmltools/convert/xgboost/_parse.py +++ b/onnxmltools/convert/xgboost/_parse.py @@ -69,9 +69,19 @@ def _get_attributes(booster): except AttributeError: ntrees = trees // num_class if num_class > 0 else trees else: - trees = len(res) - ntrees = booster.best_ntree_limit - num_class = trees // ntrees + config = json.loads(booster.save_config())["learner"]["learner_model_param"] + if "num_class" in config: + num_class = int(config["num_class"]) + ntrees = len(res) // num_class + else: + trees = len(res) + if hasattr(booster, "best_ntree_limit"): + ntrees = booster.best_ntree_limit + elif hasattr(booster, "best_iteration"): + ntrees = booster.best_iteration + else: + raise RuntimeError("Unable to guess the number of classes.") + num_class = trees // ntrees if num_class == 0: raise RuntimeError( "Unable to retrieve the number of classes, trees=%d, ntrees=%d." @@ -137,7 +147,7 @@ def __init__(self, booster): self.operator_name = "XGBRegressor" def get_xgb_params(self): - return self.kwargs + return {k: v for k, v in self.kwargs.items() if v is not None} def get_booster(self): return self.booster_ diff --git a/onnxmltools/convert/xgboost/common.py b/onnxmltools/convert/xgboost/common.py index 5b551c51..bfe93955 100644 --- a/onnxmltools/convert/xgboost/common.py +++ b/onnxmltools/convert/xgboost/common.py @@ -16,10 +16,11 @@ def get_xgb_params(xgb_node): else: # XGBoost < 0.7 params = xgb_node.__dict__ - + params = {k: v for k, v in params.items() if v is not None} if "n_estimators" not in params and hasattr(xgb_node, "n_estimators"): # xgboost >= 1.0.2 - params["n_estimators"] = xgb_node.n_estimators + if xgb_node.n_estimators is not None: + params["n_estimators"] = xgb_node.n_estimators if params.get("base_score", None) is None: # xgboost >= 2.0 if hasattr("xgb_node", "save_config"): @@ -31,3 +32,13 @@ def get_xgb_params(xgb_node): config["learner"]["learner_model_param"]["base_score"] ) return params + + +def get_n_estimators_classifier(xgb_node, params, js_trees): + if "n_estimators" not in params: + config = json.loads(xgb_node.get_booster().save_config()) + num_class = int(config["learner"]["learner_model_param"]["num_class"]) + if num_class == 0: + return len(js_trees) + return len(js_trees) // num_class + return params["n_estimators"] diff --git a/onnxmltools/convert/xgboost/operator_converters/XGBoost.py b/onnxmltools/convert/xgboost/operator_converters/XGBoost.py index 42a3964a..c48ac73c 100644 --- a/onnxmltools/convert/xgboost/operator_converters/XGBoost.py +++ b/onnxmltools/convert/xgboost/operator_converters/XGBoost.py @@ -10,7 +10,7 @@ except ImportError: XGBRFClassifier = None from ...common._registration import register_converter -from ..common import get_xgb_params +from ..common import get_xgb_params, get_n_estimators_classifier class XGBConverter: @@ -293,11 +293,13 @@ def convert(scope, operator, container): objective, base_score, js_trees = XGBConverter.common_members(xgb_node, inputs) params = XGBConverter.get_xgb_params(xgb_node) + n_estimators = get_n_estimators_classifier(xgb_node, params, js_trees) + attr_pairs = XGBClassifierConverter._get_default_tree_attribute_pairs() XGBConverter.fill_tree_attributes( js_trees, attr_pairs, [1 for _ in js_trees], True ) - ncl = (max(attr_pairs["class_treeids"]) + 1) // params["n_estimators"] + ncl = (max(attr_pairs["class_treeids"]) + 1) // n_estimators bst = xgb_node.get_booster() best_ntree_limit = getattr(bst, "best_ntree_limit", len(js_trees)) * ncl @@ -373,7 +375,7 @@ def convert(scope, operator, container): "Where", [greater, one, zero], operator.output_full_names[1] ) elif objective in ("multi:softprob", "multi:softmax"): - ncl = len(js_trees) // params["n_estimators"] + ncl = len(js_trees) // n_estimators if objective == "multi:softmax": attr_pairs["post_transform"] = "NONE" container.add_node( @@ -385,7 +387,7 @@ def convert(scope, operator, container): **attr_pairs, ) elif objective == "reg:logistic": - ncl = len(js_trees) // params["n_estimators"] + ncl = len(js_trees) // n_estimators if ncl == 1: ncl = 2 container.add_node( diff --git a/onnxmltools/convert/xgboost/shape_calculators/Classifier.py b/onnxmltools/convert/xgboost/shape_calculators/Classifier.py index a543e103..87b3990b 100644 --- a/onnxmltools/convert/xgboost/shape_calculators/Classifier.py +++ b/onnxmltools/convert/xgboost/shape_calculators/Classifier.py @@ -8,7 +8,7 @@ Int64TensorType, StringTensorType, ) -from ..common import get_xgb_params +from ..common import get_xgb_params, get_n_estimators_classifier def calculate_xgboost_classifier_output_shapes(operator): @@ -22,13 +22,15 @@ def calculate_xgboost_classifier_output_shapes(operator): params = get_xgb_params(xgb_node) booster = xgb_node.get_booster() booster.attributes() - ntrees = len(booster.get_dump(with_stats=True, dump_format="json")) + js_trees = booster.get_dump(with_stats=True, dump_format="json") + ntrees = len(js_trees) objective = params["objective"] + n_estimators = get_n_estimators_classifier(xgb_node, params, js_trees) if objective == "binary:logistic": ncl = 2 else: - ncl = ntrees // params["n_estimators"] + ncl = ntrees // n_estimators if objective == "reg:logistic" and ncl == 1: ncl = 2 classes = xgb_node.classes_ diff --git a/tests/h2o/test_h2o_converters.py b/tests/h2o/test_h2o_converters.py index 100ec5cc..56ecb832 100644 --- a/tests/h2o/test_h2o_converters.py +++ b/tests/h2o/test_h2o_converters.py @@ -223,7 +223,6 @@ def test_h2o_classifier_multi_cat(self): train, test = _prepare_one_hot("airlines.csv", y) gbm = H2OGradientBoostingEstimator(ntrees=8, max_depth=5) mojo_path = _make_mojo(gbm, train, y=train.columns.index(y)) - print("****", mojo_path) onnx_model = _convert_mojo(mojo_path) self.assertIsNot(onnx_model, None) dump_data_and_model( diff --git a/tests/xgboost/test_xgboost_converters.py b/tests/xgboost/test_xgboost_converters.py index 4b1626ea..73d5cf0e 100644 --- a/tests/xgboost/test_xgboost_converters.py +++ b/tests/xgboost/test_xgboost_converters.py @@ -677,5 +677,4 @@ def test_xgb_classifier_hinge(self): if __name__ == "__main__": - TestXGBoostModels().test_xgb_regressor_poisson() unittest.main(verbosity=2)