Fix discrepancies with XGBRegressor and xgboost > 2 (#670)
* Fix discrepancies with XGBRegressor and xgboost > 2

Signed-off-by: Xavier Dupre <xadupre@microsoft.com>

* improve CI version

Signed-off-by: Xavier Dupre <xadupre@microsoft.com>

* fix many issues

Signed-off-by: Xavier Dupre <xadupre@microsoft.com>

* more fixes

Signed-off-by: Xavier Dupre <xadupre@microsoft.com>

* many fixes for xgboost

Signed-off-by: Xavier Dupre <xadupre@microsoft.com>

* fix requirements.

Signed-off-by: Xavier Dupre <xadupre@microsoft.com>

* fix converters for old versions of xgboost

Signed-off-by: Xavier Dupre <xadupre@microsoft.com>

* fix pipeline and base_score

Signed-off-by: Xavier Dupre <xadupre@microsoft.com>

* poisson

Signed-off-by: Xavier Dupre <xadupre@microsoft.com>

---------

Signed-off-by: Xavier Dupre <xadupre@microsoft.com>
xadupre authored Dec 16, 2023
1 parent 761c1cd commit 7858f9f
Showing 15 changed files with 286 additions and 113 deletions.
20 changes: 15 additions & 5 deletions .azure-pipelines/linux-conda-CI.yml
@@ -15,13 +15,23 @@ jobs:
strategy:
matrix:

Python311-1150-RT1163-xgb2-lgbm40:
python.version: '3.11'
ONNX_PATH: 'onnx==1.15.0'
ONNXRT_PATH: 'onnxruntime==1.16.3'
COREML_PATH: NONE
lightgbm.version: '>=4.0'
xgboost.version: '>=2'
numpy.version: ''
scipy.version: ''

Python311-1150-RT1160-xgb175-lgbm40:
python.version: '3.11'
ONNX_PATH: 'onnx==1.15.0'
ONNXRT_PATH: 'onnxruntime==1.16.2'
COREML_PATH: NONE
lightgbm.version: '>=4.0'
xgboost.version: '>=1.7.5'
xgboost.version: '==1.7.5'
numpy.version: ''
scipy.version: ''

@@ -31,7 +41,7 @@ jobs:
ONNXRT_PATH: 'onnxruntime==1.16.2'
COREML_PATH: NONE
lightgbm.version: '>=4.0'
xgboost.version: '>=1.7.5'
xgboost.version: '==1.7.5'
numpy.version: ''
scipy.version: ''

@@ -41,7 +51,7 @@ jobs:
ONNXRT_PATH: 'onnxruntime==1.15.1'
COREML_PATH: NONE
lightgbm.version: '<4.0'
xgboost.version: '>=1.7.5'
xgboost.version: '==1.7.5'
numpy.version: ''
scipy.version: ''

@@ -51,7 +61,7 @@ jobs:
ONNXRT_PATH: 'onnxruntime==1.14.0'
COREML_PATH: NONE
lightgbm.version: '<4.0'
xgboost.version: '>=1.7.5'
xgboost.version: '==1.7.5'
numpy.version: ''
scipy.version: ''

@@ -61,7 +71,7 @@ jobs:
ONNXRT_PATH: 'onnxruntime==1.15.1'
COREML_PATH: NONE
lightgbm.version: '>=4.0'
xgboost.version: '>=1.7.5'
xgboost.version: '==1.7.5'
numpy.version: ''
scipy.version: '==1.8.0'

7 changes: 7 additions & 0 deletions .azure-pipelines/win32-conda-CI.yml
@@ -21,34 +21,39 @@ jobs:
ONNXRT_PATH: 'onnxruntime==1.16.2'
COREML_PATH: NONE
numpy.version: ''
xgboost.version: '2.0.2'

Python311-1141-RT1162:
python.version: '3.11'
ONNX_PATH: 'onnx==1.14.1'
ONNXRT_PATH: 'onnxruntime==1.16.2'
COREML_PATH: NONE
numpy.version: ''
xgboost.version: '1.7.5'

Python310-1141-RT1151:
python.version: '3.10'
ONNX_PATH: 'onnx==1.14.1'
ONNXRT_PATH: 'onnxruntime==1.15.1'
COREML_PATH: NONE
numpy.version: ''
xgboost.version: '1.7.5'

Python310-1141-RT1140:
python.version: '3.10'
ONNX_PATH: 'onnx==1.14.1'
ONNXRT_PATH: onnxruntime==1.14.0
COREML_PATH: NONE
numpy.version: ''
xgboost.version: '1.7.5'

Python39-1141-RT1140:
python.version: '3.9'
ONNX_PATH: 'onnx==1.14.1'
ONNXRT_PATH: onnxruntime==1.14.0
COREML_PATH: NONE
numpy.version: ''
xgboost.version: '1.7.5'

maxParallel: 3

@@ -74,6 +79,8 @@ jobs:
- script: |
call activate py$(python.version)
python -m pip install --upgrade scikit-learn
python -m pip install --upgrade lightgbm
python -m pip install "xgboost==$(xgboost.version)"
displayName: 'Install scikit-learn'
- script: |
2 changes: 2 additions & 0 deletions CHANGELOGS.md
@@ -2,6 +2,8 @@

## 1.12.0

* Fix discrepancies with XGBRegressor and xgboost > 2
[#670](https://github.com/onnx/onnxmltools/pull/670)
* Support count:poisson for XGBRegressor
[#666](https://github.com/onnx/onnxmltools/pull/666)
* Supports XGBRFClassifier and XGBRFRegressor
13 changes: 8 additions & 5 deletions onnxmltools/convert/xgboost/_parse.py
@@ -70,12 +70,15 @@ def _get_attributes(booster):
ntrees = trees // num_class if num_class > 0 else trees
else:
trees = len(res)
ntrees = booster.best_ntree_limit
num_class = trees // ntrees
ntrees = getattr(booster, "best_ntree_limit", trees)
config = json.loads(booster.save_config())["learner"]["learner_model_param"]
num_class = int(config["num_class"]) if "num_class" in config else 0
if num_class == 0 and ntrees > 0:
num_class = trees // ntrees
if num_class == 0:
raise RuntimeError(
"Unable to retrieve the number of classes, trees=%d, ntrees=%d."
% (trees, ntrees)
f"Unable to retrieve the number of classes, num_class={num_class}, "
f"trees={trees}, ntrees={ntrees}, config={config}."
)

kwargs = atts.copy()
@@ -137,7 +140,7 @@ def __init__(self, booster):
self.operator_name = "XGBRegressor"

def get_xgb_params(self):
return self.kwargs
return {k: v for k, v in self.kwargs.items() if v is not None}

def get_booster(self):
return self.booster_
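The `_parse.py` change above stops relying on `best_ntree_limit` (gone in xgboost >= 2) and instead recovers `num_class` from the booster's saved config. A minimal stdlib-only sketch of that parsing step — the payload here is illustrative; the real JSON string comes from `booster.save_config()`:

```python
import json

# Illustrative payload; only the keys the patch reads are reproduced.
# A real config comes from xgboost's booster.save_config().
sample_config = json.dumps(
    {"learner": {"learner_model_param": {"num_class": "3", "base_score": "5E-1"}}}
)


def learner_model_params(config_json):
    # Mirror the patch: num_class and base_score are stored as *strings*
    # under learner.learner_model_param and must be cast.
    params = json.loads(config_json)["learner"]["learner_model_param"]
    num_class = int(params["num_class"]) if "num_class" in params else 0
    base_score = float(params["base_score"]) if "base_score" in params else None
    return num_class, base_score


num_class, base_score = learner_model_params(sample_config)
```

Note that `base_score` is serialized in scientific notation (`"5E-1"`), which is why the patch casts through `float` rather than comparing strings.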
26 changes: 24 additions & 2 deletions onnxmltools/convert/xgboost/common.py
@@ -3,6 +3,7 @@
"""
Common function to converters and shape calculators.
"""
import json


def get_xgb_params(xgb_node):
@@ -15,8 +16,29 @@ def get_xgb_params(xgb_node):
else:
# XGBoost < 0.7
params = xgb_node.__dict__

    if hasattr(xgb_node, "save_config"):
config = json.loads(xgb_node.save_config())
else:
config = json.loads(xgb_node.get_booster().save_config())
params = {k: v for k, v in params.items() if v is not None}
num_class = int(config["learner"]["learner_model_param"]["num_class"])
if num_class > 0:
params["num_class"] = num_class
if "n_estimators" not in params and hasattr(xgb_node, "n_estimators"):
# xgboost >= 1.0.2
params["n_estimators"] = xgb_node.n_estimators
if xgb_node.n_estimators is not None:
params["n_estimators"] = xgb_node.n_estimators
if "base_score" in config["learner"]["learner_model_param"]:
bs = float(config["learner"]["learner_model_param"]["base_score"])
# xgboost >= 2.0
params["base_score"] = bs
return params


def get_n_estimators_classifier(xgb_node, params, js_trees):
if "n_estimators" not in params:
num_class = params.get("num_class", 0)
if num_class == 0:
return len(js_trees)
return len(js_trees) // num_class
return params["n_estimators"]
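The new `get_n_estimators_classifier` helper infers the boosting-round count when the model carries no explicit `n_estimators`: a multi-class model dumps one tree per class per round, so the round count is the dumped tree count divided by `num_class`. A self-contained sketch of that logic:

```python
def infer_n_estimators(params, js_trees):
    # Mirrors get_n_estimators_classifier from the patch: fall back to the
    # dumped tree count when n_estimators is absent from the parameters.
    if "n_estimators" not in params:
        num_class = params.get("num_class", 0)
        if num_class == 0:
            return len(js_trees)           # binary: one tree per round
        return len(js_trees) // num_class  # multi-class: num_class trees per round
    return params["n_estimators"]


dumped_trees = [{} for _ in range(12)]                    # 12 dumped trees
rounds_multi = infer_n_estimators({"num_class": 3}, dumped_trees)
rounds_binary = infer_n_estimators({}, dumped_trees)
```

Here a 12-tree dump yields 4 rounds for a 3-class model but 12 rounds for a binary one, which is exactly the discrepancy the converter needs to resolve before computing `ncl`.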
32 changes: 21 additions & 11 deletions onnxmltools/convert/xgboost/operator_converters/XGBoost.py
@@ -10,7 +10,7 @@
except ImportError:
XGBRFClassifier = None
from ...common._registration import register_converter
from ..common import get_xgb_params
from ..common import get_xgb_params, get_n_estimators_classifier


class XGBConverter:
@@ -161,8 +161,8 @@ def _fill_node_attributes(
false_child_id=remap[jsnode["no"]], # ['children'][1]['nodeid'],
weights=None,
weight_id_bias=None,
missing=jsnode.get("missing", -1)
== jsnode["yes"], # ['children'][0]['nodeid'],
missing=jsnode.get("missing", -1) == jsnode["yes"],
hitrate=jsnode.get("cover", 0),
)

@@ -265,8 +264,8 @@ def convert(scope, operator, container):
)

if objective == "count:poisson":
cst = scope.get_unique_variable_name("half")
container.add_initializer(cst, TensorProto.FLOAT, [1], [0.5])
cst = scope.get_unique_variable_name("poisson")
container.add_initializer(cst, TensorProto.FLOAT, [1], [base_score])
new_name = scope.get_unique_variable_name("exp")
container.add_node("Exp", names, [new_name])
container.add_node("Mul", [new_name, cst], operator.output_full_names)
@@ -293,11 +292,18 @@ def convert(scope, operator, container):
objective, base_score, js_trees = XGBConverter.common_members(xgb_node, inputs)

params = XGBConverter.get_xgb_params(xgb_node)
n_estimators = get_n_estimators_classifier(xgb_node, params, js_trees)
num_class = params.get("num_class", None)

attr_pairs = XGBClassifierConverter._get_default_tree_attribute_pairs()
XGBConverter.fill_tree_attributes(
js_trees, attr_pairs, [1 for _ in js_trees], True
)
ncl = (max(attr_pairs["class_treeids"]) + 1) // params["n_estimators"]
if num_class is not None:
ncl = num_class
n_estimators = len(js_trees) // ncl
else:
ncl = (max(attr_pairs["class_treeids"]) + 1) // n_estimators

bst = xgb_node.get_booster()
best_ntree_limit = getattr(bst, "best_ntree_limit", len(js_trees)) * ncl
@@ -310,15 +316,17 @@

if len(attr_pairs["class_treeids"]) == 0:
raise RuntimeError("XGBoost model is empty.")

if ncl <= 1:
ncl = 2
if objective != "binary:hinge":
# See https://github.com/dmlc/xgboost/blob/master/src/common/math.h#L23.
attr_pairs["post_transform"] = "LOGISTIC"
attr_pairs["class_ids"] = [0 for v in attr_pairs["class_treeids"]]
if js_trees[0].get("leaf", None) == 0:
attr_pairs["base_values"] = [0.5]
attr_pairs["base_values"] = [base_score]
elif base_score != 0.5:
# 0.5 -> cst = 0
cst = -np.log(1 / np.float32(base_score) - 1.0)
attr_pairs["base_values"] = [cst]
else:
@@ -330,8 +338,10 @@ def convert(scope, operator, container):
attr_pairs["class_ids"] = [v % ncl for v in attr_pairs["class_treeids"]]

classes = xgb_node.classes_
if np.issubdtype(classes.dtype, np.floating) or np.issubdtype(
classes.dtype, np.integer
if (
np.issubdtype(classes.dtype, np.floating)
or np.issubdtype(classes.dtype, np.integer)
or np.issubdtype(classes.dtype, np.bool_)
):
attr_pairs["classlabels_int64s"] = classes.astype("int")
else:
@@ -373,7 +383,7 @@ def convert(scope, operator, container):
"Where", [greater, one, zero], operator.output_full_names[1]
)
elif objective in ("multi:softprob", "multi:softmax"):
ncl = len(js_trees) // params["n_estimators"]
ncl = len(js_trees) // n_estimators
if objective == "multi:softmax":
attr_pairs["post_transform"] = "NONE"
container.add_node(
@@ -385,7 +395,7 @@ def convert(scope, operator, container):
**attr_pairs,
)
elif objective == "reg:logistic":
ncl = len(js_trees) // params["n_estimators"]
ncl = len(js_trees) // n_estimators
if ncl == 1:
ncl = 2
container.add_node(
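Two numeric fixes in the converter above hinge on `base_score`. For `binary:logistic`, the base value fed to the tree ensemble is the logit of `base_score` (zero for the 0.5 default, which is why it could previously be hard-coded); for `count:poisson`, the ensemble output is a log-rate, so the prediction is `exp(margin) * base_score` rather than a fixed `0.5`. A stdlib sketch of both transforms — function names are illustrative, not part of the patch:

```python
import math


def logistic_base_value(base_score):
    # binary:logistic branch of the patch: convert the stored probability
    # into the raw margin that the LOGISTIC post-transform maps back.
    return -math.log(1.0 / base_score - 1.0)


def poisson_prediction(margin, base_score):
    # count:poisson branch: an Exp node followed by a Mul with base_score,
    # replacing the previous hard-coded 0.5 multiplier.
    return math.exp(margin) * base_score
```

With the default `base_score=0.5`, `logistic_base_value` returns a zero margin, so older models keep their previous behavior; only non-default base scores change the exported graph.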
20 changes: 14 additions & 6 deletions onnxmltools/convert/xgboost/shape_calculators/Classifier.py
@@ -8,7 +8,7 @@
Int64TensorType,
StringTensorType,
)
from ..common import get_xgb_params
from ..common import get_xgb_params, get_n_estimators_classifier


def calculate_xgboost_classifier_output_shapes(operator):
@@ -22,18 +22,26 @@ def calculate_xgboost_classifier_output_shapes(operator):
params = get_xgb_params(xgb_node)
booster = xgb_node.get_booster()
booster.attributes()
ntrees = len(booster.get_dump(with_stats=True, dump_format="json"))
js_trees = booster.get_dump(with_stats=True, dump_format="json")
ntrees = len(js_trees)
objective = params["objective"]
n_estimators = get_n_estimators_classifier(xgb_node, params, js_trees)
num_class = params.get("num_class", None)

if objective == "binary:logistic":
if num_class is not None:
ncl = num_class
n_estimators = ntrees // ncl
elif objective == "binary:logistic":
ncl = 2
else:
ncl = ntrees // params["n_estimators"]
ncl = ntrees // n_estimators
if objective == "reg:logistic" and ncl == 1:
ncl = 2
classes = xgb_node.classes_
if np.issubdtype(classes.dtype, np.floating) or np.issubdtype(
classes.dtype, np.integer
if (
np.issubdtype(classes.dtype, np.floating)
or np.issubdtype(classes.dtype, np.integer)
or np.issubdtype(classes.dtype, np.bool_)
):
operator.outputs[0].type = Int64TensorType(shape=[N])
else:
2 changes: 2 additions & 0 deletions onnxmltools/utils/utils_backend.py
@@ -188,6 +188,8 @@ def compare_outputs(expected, output, **kwargs):
Disc = kwargs.pop("Disc", False)
Mism = kwargs.pop("Mism", False)
Opp = kwargs.pop("Opp", False)
if hasattr(expected, "dtype") and expected.dtype == numpy.bool_:
expected = expected.astype(numpy.int64)
if Opp and not NoProb:
raise ValueError("Opp is only available if NoProb is True")

3 changes: 2 additions & 1 deletion requirements-dev.txt
@@ -15,6 +15,7 @@ pytest-spark
ruff
scikit-learn>=1.2.0
scipy
skl2onnx
wheel
xgboost==1.7.5
xgboost
onnxruntime
5 changes: 2 additions & 3 deletions requirements.txt
@@ -1,3 +1,2 @@
numpy
onnx
skl2onnx
numpy
onnx
8 changes: 4 additions & 4 deletions tests/baseline/test_convert_baseline.py
@@ -45,11 +45,11 @@ def get_diff(self, input_file, ref_file):
def normalize_diff(self, diff):
invalid_comparisons = []
invalid_comparisons.append(
re.compile('producer_version: "\d+\.\d+\.\d+\.\d+.*')
re.compile('producer_version: "\\d+\\.\\d+\\.\\d+\\.\\d+.*')
)
invalid_comparisons.append(re.compile('\s+name: ".*'))
invalid_comparisons.append(re.compile("ir_version: \d+"))
invalid_comparisons.append(re.compile("\s+"))
invalid_comparisons.append(re.compile('\\s+name: ".*'))
invalid_comparisons.append(re.compile("ir_version: \\d+"))
invalid_comparisons.append(re.compile("\\s+"))
valid_diff = set()
for line in diff:
if any(comparison.match(line) for comparison in invalid_comparisons):
1 change: 0 additions & 1 deletion tests/h2o/test_h2o_converters.py
@@ -223,7 +223,6 @@ def test_h2o_classifier_multi_cat(self):
train, test = _prepare_one_hot("airlines.csv", y)
gbm = H2OGradientBoostingEstimator(ntrees=8, max_depth=5)
mojo_path = _make_mojo(gbm, train, y=train.columns.index(y))
print("****", mojo_path)
onnx_model = _convert_mojo(mojo_path)
self.assertIsNot(onnx_model, None)
dump_data_and_model(