diff --git a/.azure-pipelines/linux-conda-CI.yml b/.azure-pipelines/linux-conda-CI.yml index 94547e89c..c11a89daf 100644 --- a/.azure-pipelines/linux-conda-CI.yml +++ b/.azure-pipelines/linux-conda-CI.yml @@ -14,53 +14,68 @@ jobs: vmImage: 'ubuntu-latest' strategy: matrix: + Python310-1130-RT1131-xgb173: + python.version: '3.10' + ONNX_PATH: 'onnx==1.13.0' #'-i https://test.pypi.org/simple/ onnx==1.12.0rc4' + ONNXRT_PATH: onnxruntime==1.13.1 #'-i https://test.pypi.org/simple/ ort-nightly==1.11.0.dev20220311003' + COREML_PATH: NONE + xgboost.version: '>=1.7.3' + numpy.version: '' Python310-1120-RT1121-xgb161: - python.version: '3.9' + python.version: '3.10' ONNX_PATH: 'onnx==1.12.0' #'-i https://test.pypi.org/simple/ onnx==1.12.0rc4' ONNXRT_PATH: onnxruntime==1.12.1 #'-i https://test.pypi.org/simple/ ort-nightly==1.11.0.dev20220311003' - COREML_PATH: git+https://github.com/apple/coremltools@3.1 - xgboost.version: '>=1.6.1' + COREML_PATH: NONE + xgboost.version: '==1.6.1' + numpy.version: '' Python39-1120-RT1110-xgb161: python.version: '3.9' ONNX_PATH: 'onnx==1.12.0' #'-i https://test.pypi.org/simple/ onnx==1.12.0rc4' ONNXRT_PATH: onnxruntime==1.11.0 #'-i https://test.pypi.org/simple/ ort-nightly==1.11.0.dev20220311003' COREML_PATH: git+https://github.com/apple/coremltools@3.1 xgboost.version: '>=1.6.1' + numpy.version: '' Python39-1120-RT1110-xgb142: python.version: '3.9' ONNX_PATH: 'onnx==1.12.0' #'-i https://test.pypi.org/simple/ onnx==1.12.0rc4' ONNXRT_PATH: onnxruntime==1.11.0 #'-i https://test.pypi.org/simple/ ort-nightly==1.11.0.dev20220311003' COREML_PATH: git+https://github.com/apple/coremltools@3.1 xgboost.version: '==1.4.2' + numpy.version: '' Python39-1110-RT1110: python.version: '3.9' ONNX_PATH: onnx==1.11.0 # '-i https://test.pypi.org/simple/ onnx==1.9.101' ONNXRT_PATH: onnxruntime==1.11.0 #'-i https://test.pypi.org/simple/ ort-nightly==1.11.0.dev20220311003' COREML_PATH: git+https://github.com/apple/coremltools@3.1 + numpy.version: '' Python39-1110-RT1100: python.version: '3.9' ONNX_PATH: onnx==1.11.0 # '-i https://test.pypi.org/simple/ onnx==1.9.101' ONNXRT_PATH: onnxruntime==1.10.0 COREML_PATH: git+https://github.com/apple/coremltools@3.1 xgboost.version: '>=1.2' + numpy.version: '' Python39-1101-RT190: python.version: '3.9' ONNX_PATH: onnx==1.10.1 ONNXRT_PATH: onnxruntime==1.9.0 COREML_PATH: git+https://github.com/apple/coremltools@3.1 xgboost.version: '>=1.2' + numpy.version: '<=1.23.5' Python39-190-RT180-xgb11: python.version: '3.9' ONNX_PATH: onnx==1.9.0 ONNXRT_PATH: onnxruntime==1.8.0 COREML_PATH: git+https://github.com/apple/coremltools@3.1 xgboost.version: '>=1.2' + numpy.version: '<=1.23.5' Python38-181-RT170-xgb11: python.version: '3.8' ONNX_PATH: onnx==1.8.1 ONNXRT_PATH: onnxruntime==1.7.0 COREML_PATH: git+https://github.com/apple/coremltools@3.1 xgboost.version: '>=1.2' + numpy.version: '<=1.23.5' maxParallel: 3 @@ -85,7 +100,7 @@ jobs: pip install xgboost$(xgboost.version) pip install $(ONNX_PATH) pip install $(ONNXRT_PATH) - pip install $(COREML_PATH) + pip install "numpy$(numpy.version)" displayName: 'Install xgboost, onnxruntime' - script: | @@ -149,7 +164,7 @@ jobs: displayName: 'pytest - h2o' - script: | - pip install torch==1.8.1+cpu torchvision==0.9.1+cpu torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html + pip install torch --extra-index-url https://download.pytorch.org/whl/cpu pip install hummingbird-ml --no-deps displayName: 'Install hummingbird-ml' @@ -159,9 +174,16 @@ jobs: displayName: 'pytest - hummingbirdml' - script: | - export PYTHONPATH=. - pytest tests/coreml --durations=0 - displayName: 'pytest - coreml' + if [ '$(COREML_PATH)' == 'NONE' ] + then + echo "required version of coremltools does not work on python 3.10" + else + export PYTHONPATH=. + pip install $(COREML_PATH) + pytest tests/coreml --durations=0 + fi + displayName: 'pytest - coreml [$(COREML_PATH)]' + # condition: ne('$(COREML_PATH)', 'NONE') - task: PublishTestResults@2 inputs: diff --git a/.azure-pipelines/win32-conda-CI.yml b/.azure-pipelines/win32-conda-CI.yml index 128e77891..2d44629b7 100644 --- a/.azure-pipelines/win32-conda-CI.yml +++ b/.azure-pipelines/win32-conda-CI.yml @@ -15,47 +15,61 @@ jobs: strategy: matrix: + Python310-1130-RT1131: + python.version: '3.10' + ONNX_PATH: 'onnx==1.13.0' # '-i https://test.pypi.org/simple/ onnx==1.12.0rc4' + ONNXRT_PATH: onnxruntime==1.13.1 #'-i https://test.pypi.org/simple/ ort-nightly==1.11.0.dev20220311003' + COREML_PATH: NONE + numpy.version: '' + Python39-1120-RT1110: python.version: '3.9' ONNX_PATH: 'onnx==1.12.0' # '-i https://test.pypi.org/simple/ onnx==1.12.0rc4' ONNXRT_PATH: onnxruntime==1.11.0 #'-i https://test.pypi.org/simple/ ort-nightly==1.11.0.dev20220311003' COREML_PATH: git+https://github.com/apple/coremltools@3.1 + numpy.version: '' Python39-1110-RT1110: python.version: '3.9' ONNX_PATH: onnx==1.11.0 # '-i https://test.pypi.org/simple/ onnx==1.9.101' ONNXRT_PATH: onnxruntime==1.11.0 #'-i https://test.pypi.org/simple/ ort-nightly==1.11.0.dev20220311003' COREML_PATH: git+https://github.com/apple/coremltools@3.1 + numpy.version: '' Python39-1110-RT190: python.version: '3.9' ONNX_PATH: 'onnx==1.11.0' # '-i https://test.pypi.org/simple/ onnx==1.9.101' ONNXRT_PATH: onnxruntime==1.10.0 COREML_PATH: git+https://github.com/apple/coremltools@3.1 + numpy.version: '' Python39-1102-RT190: python.version: '3.9' ONNX_PATH: 'onnx==1.10.2' # '-i https://test.pypi.org/simple/ onnx==1.9.101' ONNXRT_PATH: onnxruntime==1.9.0 COREML_PATH: git+https://github.com/apple/coremltools@3.1 + numpy.version: '<=1.23.5' Python39-190-RT181: python.version: '3.9' ONNX_PATH: 'onnx==1.9.0' ONNXRT_PATH: onnxruntime==1.8.1 COREML_PATH: git+https://github.com/apple/coremltools@3.1 + numpy.version: '<=1.23.5' Python39-190-RT180: python.version: '3.9' ONNX_PATH: onnx==1.9.0 ONNXRT_PATH: onnxruntime==1.8.0 COREML_PATH: git+https://github.com/apple/coremltools@3.1 + numpy.version: '<=1.23.5' Python38-181-RT170: python.version: '3.8' ONNX_PATH: onnx==1.8.1 ONNXRT_PATH: onnxruntime==1.7.0 COREML_PATH: git+https://github.com/apple/coremltools@3.1 + numpy.version: '<=1.23.5' maxParallel: 3 @@ -83,15 +97,11 @@ jobs: python -m pip install --upgrade scikit-learn displayName: 'Install scikit-learn' - - script: | - call activate py$(python.version) - python -m pip install %COREML_PATH% - displayName: 'Install coremltools' - - script: | call activate py$(python.version) python -m pip install %ONNX_PATH% python -m pip install %ONNXRT_PATH% + python -m pip install "numpy$(numpy.version)" displayName: 'Install onnxruntime' - script: | @@ -124,9 +134,11 @@ jobs: - script: | call activate py$(python.version) - export PYTHONPATH=. - python -m pytest tests/coreml --durations=0 - displayName: 'pytest coreml' + set PYTHONPATH=. + if "$(COREML_PATH)" neq "NONE" python -m pip install %COREML_PATH% + if "$(COREML_PATH)" neq "NONE" python -m pytest tests/coreml --durations=0 + displayName: 'pytest coreml - [$(COREML_PATH)]' + #condition: ne('$(COREML_PATH)', 'NONE') - script: | call activate py$(python.version) @@ -160,7 +172,7 @@ jobs: - script: | call activate py$(python.version) - python -m pip install torch==1.8.1+cpu torchvision==0.9.1+cpu torchaudio===0.8.1 -f https://download.pytorch.org/whl/torch_stable.html + python -m pip install torch python -m pip install hummingbird-ml --no-deps displayName: 'Install hummingbird-ml' diff --git a/onnxmltools/convert/sparkml/operator_converters/aft_survival_regression.py b/onnxmltools/convert/sparkml/operator_converters/aft_survival_regression.py index 5ae267ecc..c1e4512d4 100644 --- a/onnxmltools/convert/sparkml/operator_converters/aft_survival_regression.py +++ b/onnxmltools/convert/sparkml/operator_converters/aft_survival_regression.py @@ -15,7 +15,10 @@ def convert_aft_survival_regression(scope, operator, container): coefficients = op.coefficients.toArray().astype(float) coefficients_tensor = scope.get_unique_variable_name('coefficients_tensor') container.add_initializer(coefficients_tensor, onnx_proto.TensorProto.FLOAT, [1, len(coefficients)], coefficients) - intercepts = op.intercept.astype(float) if isinstance(op.intercept, collections.Iterable) else [float(op.intercept)] + intercepts = ( + op.intercept.astype(float) + if isinstance(op.intercept, collections.abc.Iterable) + else [float(op.intercept)]) intercepts_tensor = scope.get_unique_variable_name('intercepts_tensor') container.add_initializer(intercepts_tensor, onnx_proto.TensorProto.FLOAT, [len(intercepts)], intercepts) diff --git a/onnxmltools/convert/sparkml/operator_converters/linear_regressor.py b/onnxmltools/convert/sparkml/operator_converters/linear_regressor.py index d5b65e597..bf09b9aae 100644 --- a/onnxmltools/convert/sparkml/operator_converters/linear_regressor.py +++ b/onnxmltools/convert/sparkml/operator_converters/linear_regressor.py @@ -13,8 +13,10 @@ def convert_sparkml_linear_regressor(scope, operator, container): attrs = { 'name': scope.get_unique_operator_name(op_type), 'coefficients': op.coefficients.astype(float), - 'intercepts': op.intercept.astype(float) if isinstance(op.intercept, collections.Iterable) else [ - float(op.intercept)] + 'intercepts': ( + op.intercept.astype(float) + if isinstance(op.intercept, collections.abc.Iterable) + else [float(op.intercept)]) } container.add_node(op_type, operator.input_full_names, operator.output_full_names, op_domain='ai.onnx.ml', **attrs) diff --git a/onnxmltools/convert/xgboost/convert.py b/onnxmltools/convert/xgboost/convert.py index 0592550eb..debc05c2d 100644 --- a/onnxmltools/convert/xgboost/convert.py +++ b/onnxmltools/convert/xgboost/convert.py @@ -41,4 +41,13 @@ def convert(model, name=None, initial_types=None, doc_string='', target_opset=No topology = parse_xgboost(model, initial_types, target_opset, custom_conversion_functions, custom_shape_calculators) topology.compile() onnx_model = convert_topology(topology, name, doc_string, target_opset, targeted_onnx) + opsets = {d.domain: d.version for d in onnx_model.opset_import} + if '' in opsets and opsets[''] < 9 and target_opset >= 9: + # a bug? + opsets[''] = 9 + del onnx_model.opset_import[:] + for k, v in opsets.items(): + opset = onnx_model.opset_import.add() + opset.domain = k + opset.version = v return onnx_model diff --git a/onnxmltools/convert/xgboost/operator_converters/XGBoost.py b/onnxmltools/convert/xgboost/operator_converters/XGBoost.py index eebe9b2fb..4ae246abd 100644 --- a/onnxmltools/convert/xgboost/operator_converters/XGBoost.py +++ b/onnxmltools/convert/xgboost/operator_converters/XGBoost.py @@ -2,6 +2,7 @@ import json import numpy as np +from onnx import TensorProto from xgboost import XGBClassifier from ...common._registration import register_converter from ..common import get_xgb_params @@ -241,14 +242,17 @@ def convert(scope, operator, container): raise RuntimeError("XGBoost model is empty.") if ncl <= 1: ncl = 2 - # See https://github.com/dmlc/xgboost/blob/master/src/common/math.h#L23. - attr_pairs['post_transform'] = "LOGISTIC" - attr_pairs['class_ids'] = [0 for v in attr_pairs['class_treeids']] - if js_trees[0].get('leaf', None) == 0: - attr_pairs['base_values'] = [0.5] - elif base_score != 0.5: - cst = - np.log(1 / np.float32(base_score) - 1.) - attr_pairs['base_values'] = [cst] + if objective != 'binary:hinge': + # See https://github.com/dmlc/xgboost/blob/master/src/common/math.h#L23. + attr_pairs['post_transform'] = "LOGISTIC" + attr_pairs['class_ids'] = [0 for v in attr_pairs['class_treeids']] + if js_trees[0].get('leaf', None) == 0: + attr_pairs['base_values'] = [0.5] + elif base_score != 0.5: + cst = - np.log(1 / np.float32(base_score) - 1.) + attr_pairs['base_values'] = [cst] + else: + attr_pairs['base_values'] = [base_score] else: # See https://github.com/dmlc/xgboost/blob/master/src/common/math.h#L35. attr_pairs['post_transform'] = "SOFTMAX" @@ -264,13 +268,33 @@ def convert(scope, operator, container): attr_pairs['classlabels_strings'] = classes # add nodes - if objective == "binary:logistic": + if objective in ("binary:logistic", "binary:hinge"): ncl = 2 - container.add_node('TreeEnsembleClassifier', operator.input_full_names, - operator.output_full_names, + if objective == "binary:hinge": + attr_pairs['post_transform'] = 'NONE' + output_names = [operator.output_full_names[0], + scope.get_unique_variable_name("output_prob")] + else: + output_names = operator.output_full_names + container.add_node('TreeEnsembleClassifier', + operator.input_full_names, + output_names, op_domain='ai.onnx.ml', name=scope.get_unique_operator_name('TreeEnsembleClassifier'), **attr_pairs) + if objective == "binary:hinge": + if container.target_opset < 9: + raise RuntimeError( + f"hinge function cannot be implemented because " + f"opset={container.target_opset}<9.") + zero = scope.get_unique_variable_name("zero") + one = scope.get_unique_variable_name("one") + container.add_initializer(zero, TensorProto.FLOAT, [1], [0.]) + container.add_initializer(one, TensorProto.FLOAT, [1], [1.]) + greater = scope.get_unique_variable_name("output_prob") + container.add_node("Greater", [output_names[1], zero], [greater]) + container.add_node('Where', [greater, one, zero], + operator.output_full_names[1]) elif objective in ("multi:softprob", "multi:softmax"): ncl = len(js_trees) // params['n_estimators'] if objective == 'multi:softmax': diff --git a/onnxmltools/proto/__init__.py b/onnxmltools/proto/__init__.py index 009252baf..ff24c446a 100644 --- a/onnxmltools/proto/__init__.py +++ b/onnxmltools/proto/__init__.py @@ -21,7 +21,7 @@ def _check_onnx_version(): _check_onnx_version() -def _make_tensor_fixed(name, data_type, dims, vals, raw=False): +def make_tensor_fixed(name, data_type, dims, vals, raw=False): ''' Make a TensorProto with specified arguments. If raw is False, this function will choose the corresponding proto field to store the @@ -45,6 +45,3 @@ def _make_tensor_fixed(name, data_type, dims, vals, raw=False): tensor.dims.extend(dims) return tensor - - -helper.make_tensor = _make_tensor_fixed diff --git a/onnxmltools/utils/utils_backend_onnxruntime.py b/onnxmltools/utils/utils_backend_onnxruntime.py index 746220690..bb755a780 100644 --- a/onnxmltools/utils/utils_backend_onnxruntime.py +++ b/onnxmltools/utils/utils_backend_onnxruntime.py @@ -52,7 +52,7 @@ def compare_runtime(test, decimal=5, options=None, verbose=False, context=None): return try: - sess = onnxruntime.InferenceSession(onx) + sess = onnxruntime.InferenceSession(onx, providers=["CPUExecutionProvider"]) except ExpectedAssertionError as expe: raise expe except Exception as e: @@ -312,7 +312,7 @@ def run_with_runtime(inputs, model_path): ''' try: import onnxruntime - session = onnxruntime.InferenceSession(model_path) + session = onnxruntime.InferenceSession(model_path, providers=["CPUExecutionProvider"]) output = session.run(None, inputs) return (output, session) except Exception as e: diff --git a/tests/baseline/test_convert_baseline.py b/tests/baseline/test_convert_baseline.py index 962e85298..ddabacacf 100644 --- a/tests/baseline/test_convert_baseline.py +++ b/tests/baseline/test_convert_baseline.py @@ -9,7 +9,10 @@ from onnx.defs import onnx_opset_version from onnxmltools.convert import convert_coreml from onnxconverter_common.onnx_ex import DEFAULT_OPSET_NUMBER -import coremltools +try: + import coremltools +except ImportError: + coremltools = None TARGET_OPSET = min(DEFAULT_OPSET_NUMBER, onnx_opset_version()) @@ -24,6 +27,8 @@ def check_baseline(self, input_file, ref_file): def get_diff(self, input_file, ref_file): this = os.path.dirname(__file__) coreml_file = os.path.join(this, "models", input_file) + if coremltools is None: + return [] cml = coremltools.utils.load_spec(coreml_file) onnx_model = convert_coreml(cml, target_opset=TARGET_OPSET) output_dir = os.path.join(this, "outmodels") diff --git a/tests/h2o/test_h2o_converters.py b/tests/h2o/test_h2o_converters.py index 5c3825ae8..a4d8a2145 100644 --- a/tests/h2o/test_h2o_converters.py +++ b/tests/h2o/test_h2o_converters.py @@ -91,7 +91,7 @@ def _prepare_one_hot(file, y, exclude_cols=None): def _train_test_split_as_frames(x, y, is_str=False, is_classifier=False): - y = y.astype(np.str) if is_str else y.astype(np.int64) + y = y.astype(np.str_) if is_str else y.astype(np.int64) x_train, x_test, y_train, _ = train_test_split(x, y, test_size=0.3, random_state=42) f_train_x = H2OFrame(x_train) f_train_y = H2OFrame(y_train) @@ -138,7 +138,7 @@ def predict_with_probabilities(self, data): return [preds.to_numpy()] else: return [ - preds.iloc[:, 0].to_numpy().astype(np.str), + preds.iloc[:, 0].to_numpy().astype(np.str_), preds.iloc[:, 1:].to_numpy() ] diff --git a/tests/xgboost/test_xgboost_13.py b/tests/xgboost/test_xgboost_13.py index ef599aefc..ac877a576 100644 --- a/tests/xgboost/test_xgboost_13.py +++ b/tests/xgboost/test_xgboost_13.py @@ -41,7 +41,7 @@ def test_xgb_regressor(self): initial_type = [('float_input', FloatTensorType([None, 797]))] onx = convert_xgboost(clr, initial_types=initial_type, target_opset=TARGET_OPSET) expected = clr.predict(X_test), clr.predict_proba(X_test) - sess = InferenceSession(onx.SerializeToString()) + sess = InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"]) X_test = X_test.values.astype(np.float32) got = sess.run(None, {'float_input': X_test}) assert_almost_equal(expected[1], got[1]) diff --git a/tests/xgboost/test_xgboost_converters.py b/tests/xgboost/test_xgboost_converters.py index 21e40eba5..9845c1e0d 100644 --- a/tests/xgboost/test_xgboost_converters.py +++ b/tests/xgboost/test_xgboost_converters.py @@ -9,7 +9,7 @@ from numpy.testing import assert_almost_equal import pandas from sklearn.datasets import ( - load_diabetes, load_iris, make_classification, load_digits) + load_diabetes, load_iris, make_classification, load_digits, make_regression) from sklearn.model_selection import train_test_split from xgboost import XGBRegressor, XGBClassifier, train, DMatrix, Booster, train as train_xgb from sklearn.preprocessing import StandardScaler @@ -24,11 +24,25 @@ TARGET_OPSET = min(DEFAULT_OPSET_NUMBER, onnx_opset_version()) +def fct_cl2(y): + y[y == 2] = 0 + return y + + +def fct_cl3(y): + y[y == 0] = 6 + return y + + +def fct_id(y): + return y + + def _fit_classification_model(model, n_classes, is_str=False, dtype=None): x, y = make_classification(n_classes=n_classes, n_features=100, n_samples=1000, random_state=42, n_informative=7) - y = y.astype(np.str) if is_str else y.astype(np.int64) + y = y.astype(np.str_) if is_str else y.astype(np.int64) x_train, x_test, y_train, _ = train_test_split(x, y, test_size=0.5, random_state=42) if dtype is not None: @@ -235,7 +249,7 @@ def test_xgboost_classifier_i5450_softmax(self): clr.fit(X_train, y_train, eval_set=[(X_test, y_test)], early_stopping_rounds=40) initial_type = [('float_input', FloatTensorType([None, 4]))] onx = convert_xgboost(clr, initial_types=initial_type, target_opset=TARGET_OPSET) - sess = InferenceSession(onx.SerializeToString()) + sess = InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"]) input_name = sess.get_inputs()[0].name label_name = sess.get_outputs()[1].name predict_list = [1., 20., 466., 0.] @@ -255,7 +269,7 @@ def test_xgboost_classifier_i5450(self): clr.fit(X_train, y_train, eval_set=[(X_test, y_test)], early_stopping_rounds=40) initial_type = [('float_input', FloatTensorType([None, 4]))] onx = convert_xgboost(clr, initial_types=initial_type, target_opset=TARGET_OPSET) - sess = InferenceSession(onx.SerializeToString()) + sess = InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"]) input_name = sess.get_inputs()[0].name label_name = sess.get_outputs()[1].name predict_list = [1., 20., 466., 0.] @@ -304,7 +318,7 @@ def test_xgb_empty_tree(self): xgb, initial_types=[ ('input', FloatTensorType(shape=[None, X.shape[1]]))], target_opset=TARGET_OPSET) - sess = InferenceSession(conv_model.SerializeToString()) + sess = InferenceSession(conv_model.SerializeToString(), providers=["CPUExecutionProvider"]) res = sess.run(None, {'input': X.astype(np.float32)}) assert_almost_equal(xgb.predict_proba(X), res[1]) assert_almost_equal(xgb.predict(X), res[0]) @@ -325,7 +339,7 @@ def test_xgb_best_tree_limit(self): onx_loaded = convert_xgboost( bst_original, initial_types=initial_type, target_opset=TARGET_OPSET) - sess = InferenceSession(onx_loaded.SerializeToString()) + sess = InferenceSession(onx_loaded.SerializeToString(), providers=["CPUExecutionProvider"]) res = sess.run(None, {'float_input': X_test.astype(np.float32)}) assert_almost_equal(bst_original.predict(dtest, output_margin=True), res[1], decimal=5) assert_almost_equal(bst_original.predict(dtest), res[0]) @@ -342,7 +356,7 @@ def test_xgb_best_tree_limit(self): onx_loaded = convert_xgboost( bst_loaded, initial_types=initial_type, target_opset=TARGET_OPSET) - sess = InferenceSession(onx_loaded.SerializeToString()) + sess = InferenceSession(onx_loaded.SerializeToString(), providers=["CPUExecutionProvider"]) res = sess.run(None, {'float_input': X_test.astype(np.float32)}) assert_almost_equal(bst_loaded.predict(dtest, output_margin=True), res[1], decimal=5) assert_almost_equal(bst_loaded.predict(dtest), res[0]) @@ -364,11 +378,119 @@ def test_onnxrt_python_xgbclassifier(self): model_skl, initial_types=[('X', FloatTensorType([None, x.shape[1]]))], target_opset=TARGET_OPSET) with self.subTest(base_score=bm, n_estimators=n_est): - oinf = InferenceSession(model_onnx_skl.SerializeToString()) + oinf = InferenceSession(model_onnx_skl.SerializeToString(), providers=["CPUExecutionProvider"]) res2 = oinf.run(None, {'X': x_test}) assert_almost_equal(model_skl.predict_proba(x_test), res2[1]) + def test_xgb_cost(self): + obj_classes = { + 'reg:logistic': (XGBClassifier, fct_cl2, + make_classification(n_features=4, n_classes=2, + n_clusters_per_class=1)), + 'binary:logistic': (XGBClassifier, fct_cl2, + make_classification(n_features=4, n_classes=2, + n_clusters_per_class=1)), + 'multi:softmax': (XGBClassifier, fct_id, + make_classification(n_features=4, n_classes=3, + n_clusters_per_class=1)), + 'multi:softprob': (XGBClassifier, fct_id, + make_classification(n_features=4, n_classes=3, + n_clusters_per_class=1)), + 'reg:squarederror': (XGBRegressor, fct_id, + make_regression(n_features=4, n_targets=1)), + 'reg:squarederror2': (XGBRegressor, fct_id, + make_regression(n_features=4, n_targets=2)), + } + nb_tests = 0 + for objective in obj_classes: # pylint: disable=C0206 + for n_estimators in [1, 2]: + with self.subTest(objective=objective, n_estimators=n_estimators): + probs = [] + cl, fct, prob = obj_classes[objective] + + iris = load_iris() + X, y = iris.data, iris.target + y = fct(y) + X_train, X_test, y_train, _ = train_test_split( + X, y, random_state=11) + probs.append((X_train, X_test, y_train)) + + X_train, X_test, y_train, _ = train_test_split( + *prob, random_state=11) + probs.append((X_train, X_test, y_train)) + + for X_train, X_test, y_train in probs: + obj = objective.replace( + 'reg:squarederror2', 'reg:squarederror') + obj = obj.replace( + 'multi:softmax2', 'multi:softmax') + clr = cl(objective=obj, n_estimators=n_estimators) + if len(y_train.shape) == 2: + y_train = y_train[:, 1] + try: + clr.fit(X_train, y_train) + except ValueError as e: + raise AssertionError( + "Unable to train with objective %r and data %r." % ( + objective, y_train)) from e + + model_def = convert_xgboost( + clr, initial_types=[('X', FloatTensorType([None, X.shape[1]]))], + target_opset=TARGET_OPSET) + + oinf = InferenceSession(model_def.SerializeToString(), + providers=["CPUExecutionProvider"]) + y = oinf.run(None, {'X': X_test.astype(np.float32)}) + if cl == XGBRegressor: + exp = clr.predict(X_test) + assert_almost_equal(exp, y[0].ravel(), decimal=5) + else: + if 'softmax' not in obj: + exp = clr.predict_proba(X_test) + got = pandas.DataFrame(y[1]).values + assert_almost_equal(exp, got, decimal=5) + + exp = clr.predict(X_test[:10]) + assert_almost_equal(exp, y[0][:10]) + + nb_tests += 1 + + self.assertGreater(nb_tests, 8) + + def test_xgb_classifier_601(self): + model = XGBClassifier( + base_score=0.5, booster='gbtree', colsample_bylevel=1, + colsample_bynode=1, colsample_bytree=1, gamma=0, + importance_type='gain', interaction_constraints='', + learning_rate=0.3, max_delta_step=0, max_depth=6, + min_child_weight=1, missing=np.nan, + n_estimators=3, n_jobs=0, num_parallel_tree=1, + objective='multi:softprob', random_state=0, reg_alpha=0, + reg_lambda=1, scale_pos_weight=None, subsample=1, + tree_method='exact', validate_parameters=1) + xgb, x_test = _fit_classification_model(model, 3) + conv_model = convert_xgboost( + xgb, initial_types=[('input', FloatTensorType(shape=[None, None]))], + target_opset=TARGET_OPSET) + self.assertTrue(conv_model is not None) + dump_data_and_model( + x_test, xgb, conv_model, + basename="SklearnXGBClassifier601") + + def test_xgb_classifier_hinge(self): + model = XGBClassifier( + n_estimators=3, objective='binary:hinge', random_state=0, + max_depth=2) + xgb, x_test = _fit_classification_model(model, 2) + conv_model = convert_xgboost( + xgb, initial_types=[('input', FloatTensorType(shape=[None, None]))], + target_opset=TARGET_OPSET) + dump_data_and_model( + x_test, xgb, conv_model, + basename="SklearnXGBClassifierHinge") + + if __name__ == "__main__": # TestXGBoostModels().test_xgboost_booster_classifier_multiclass_softprob() - unittest.main() + unittest.main(verbosity=2) diff --git a/tests/xgboost/test_xgboost_pipeline.py b/tests/xgboost/test_xgboost_pipeline.py index 64bd81a04..4a3d62d7f 100644 --- a/tests/xgboost/test_xgboost_pipeline.py +++ b/tests/xgboost/test_xgboost_pipeline.py @@ -136,7 +136,7 @@ def common_test_xgboost_10_skl(self, missing, replace=False): onnx_last = convert_sklearn(model.steps[1][-1], initial_types=[('X', FloatTensorType(shape=[None, input_xgb.shape[1]]))], target_opset={'': TARGET_OPSET, 'ai.onnx.ml': TARGET_OPSET_ML}) - session = rt.InferenceSession(onnx_last.SerializeToString()) + session = rt.InferenceSession(onnx_last.SerializeToString(), providers=["CPUExecutionProvider"]) pred_skl = model.steps[1][-1].predict(input_xgb).ravel() pred_onx = session.run(None, {'X': input_xgb})[0].ravel() assert_almost_equal(pred_skl, pred_onx)