diff --git a/.azure-pipelines/linux-CI-nightly.yml b/.azure-pipelines/linux-CI-nightly.yml index af09d8d2..0135a2fd 100644 --- a/.azure-pipelines/linux-CI-nightly.yml +++ b/.azure-pipelines/linux-CI-nightly.yml @@ -47,7 +47,6 @@ jobs: python -m pip install $(ONNX_PATH) python -m pip install hummingbird-ml --no-deps python -m pip install -r requirements.txt - pip install torch==1.8.1+cpu torchvision==0.9.1+cpu torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html python -m pip install -r requirements-dev.txt python -m pip install $(ORT_PATH) python -m pip install pytest diff --git a/.azure-pipelines/linux-conda-CI.yml b/.azure-pipelines/linux-conda-CI.yml index 1df47749..ba70511b 100644 --- a/.azure-pipelines/linux-conda-CI.yml +++ b/.azure-pipelines/linux-conda-CI.yml @@ -9,6 +9,7 @@ trigger: jobs: - job: 'Test' + timeoutInMinutes: 25 pool: vmImage: 'ubuntu-latest' strategy: @@ -70,31 +71,28 @@ jobs: maxParallel: 3 steps: - - script: sudo install -d -m 0777 /home/vsts/.conda/envs - displayName: Fix Conda permissions - - - task: CondaEnvironment@1 + - task: UsePythonVersion@0 inputs: - createCustomEnvironment: true - environmentName: 'py$(python.version)' - packageSpecs: 'python=$(python.version)' + versionSpec: '$(python.version)' + architecture: 'x64' - script: | python -m pip install --upgrade pip - conda config --set always_yes yes --set changeps1 no - conda install -c conda-forge protobuf - conda install -c conda-forge numpy - conda install -c conda-forge cmake - pip install $(COREML_PATH) - pip install $(ONNX_PATH) - pip install hummingbird-ml --no-deps + pip install $(ONNX_PATH) $(ONNXRT_PATH) cython pip install -r requirements.txt - pip install torch==1.8.1+cpu torchvision==0.9.1+cpu torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html + displayName: 'Install dependencies' + + - script: | pip install -r requirements-dev.txt + displayName: 'Install dependencies-dev' + + - script: | + python -m pip install --upgrade pip pip install xgboost$(xgboost.version) + pip install $(ONNX_PATH) pip install $(ONNXRT_PATH) - pip install pytest - displayName: 'Install dependencies' + pip install $(COREML_PATH) + displayName: 'Install xgboost, onnxruntime' - script: | pip install flake8 @@ -109,8 +107,63 @@ jobs: export PYTHONPATH=. python -c "import onnxconverter_common;print(onnxconverter_common.__version__)" python -c "import onnxruntime;print(onnxruntime.__version__)" - pytest tests --doctest-modules --junitxml=junit/test-results.xml - displayName: 'pytest - onnxmltools' + displayName: 'version' + + - script: | + export PYTHONPATH=. + pytest tests/baseline --durations=0 + displayName: 'pytest - baseline' + + - script: | + export PYTHONPATH=. + pytest tests/catboost --durations=0 + displayName: 'pytest - catboost' + + - script: | + export PYTHONPATH=. + pytest tests/coreml --durations=0 + displayName: 'pytest - coreml' + + - script: | + export PYTHONPATH=. + pytest tests/lightgbm --durations=0 + displayName: 'pytest - lightgbm' + + - script: | + export PYTHONPATH=. + pytest tests/sparkml --durations=0 + displayName: 'pytest - sparkml' + + - script: | + export PYTHONPATH=. + pytest tests/svmlib --durations=0 + displayName: 'pytest - svmlib' + + - script: | + export PYTHONPATH=. + pytest tests/utils --durations=0 + displayName: 'pytest - utils' + + - script: | + export PYTHONPATH=. + pytest tests/xgboost --durations=0 + displayName: 'pytest - xgboost' + + - script: | + export PYTHONPATH=. 
+ pip install h2o + pytest tests/h2o --durations=0 + displayName: 'pytest - h2o' + + - script: | + pip install torch==1.8.1+cpu torchvision==0.9.1+cpu torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html + pip install hummingbird-ml --no-deps + displayName: 'Install hummingbird-ml' + + - script: | + export PYTHONPATH=. + pytest tests/hummingbirdml --durations=0 + displayName: 'pytest - hummingbirdml' - task: PublishTestResults@2 inputs: diff --git a/.azure-pipelines/win32-CI-nightly.yml b/.azure-pipelines/win32-CI-nightly.yml index 1176a5ad..3aad5d61 100644 --- a/.azure-pipelines/win32-CI-nightly.yml +++ b/.azure-pipelines/win32-CI-nightly.yml @@ -45,7 +45,6 @@ jobs: pip install %COREML_PATH% %ONNX_PATH% pip install humming-bird-ml --no-deps pip install -r requirements.txt - python -m pip install torch==1.8.1+cpu torchvision==0.9.1+cpu torchaudio===0.8.1 -f https://download.pytorch.org/whl/torch_stable.html pip install -r requirements-dev.txt pip install %ONNXRT_PATH% displayName: 'Install dependencies' diff --git a/.azure-pipelines/win32-conda-CI.yml b/.azure-pipelines/win32-conda-CI.yml index 0e073eb5..c1329c7b 100644 --- a/.azure-pipelines/win32-conda-CI.yml +++ b/.azure-pipelines/win32-conda-CI.yml @@ -9,6 +9,7 @@ trigger: jobs: - job: 'Test' + timeoutInMinutes: 30 pool: vmImage: 'windows-latest' strategy: @@ -18,79 +19,63 @@ jobs: ONNX_PATH: 'onnx==1.10.1' # '-i https://test.pypi.org/simple/ onnx==1.9.101' ONNXRT_PATH: onnxruntime==1.8.1 COREML_PATH: git+https://github.com/apple/coremltools@3.1 - sklearn.version: '' Python39-190-RT181: python.version: '3.9' ONNX_PATH: 'onnx==1.9.0' ONNXRT_PATH: onnxruntime==1.8.1 COREML_PATH: git+https://github.com/apple/coremltools@3.1 - sklearn.version: '' Python39-190-RT180: python.version: '3.9' ONNX_PATH: onnx==1.9.0 ONNXRT_PATH: onnxruntime==1.8.0 COREML_PATH: git+https://github.com/apple/coremltools@3.1 - sklearn.version: '' Python38-181-RT170: python.version: '3.8' ONNX_PATH: onnx==1.8.1 ONNXRT_PATH: onnxruntime==1.7.0 COREML_PATH: git+https://github.com/apple/coremltools@3.1 - sklearn.version: '' Python37-180-RT160: python.version: '3.7' ONNX_PATH: onnx==1.8.0 ONNXRT_PATH: onnxruntime==1.6.0 COREML_PATH: git+https://github.com/apple/coremltools@3.1 - sklearn.version: '' Python37-160-RT111: python.version: '3.7' ONNX_PATH: onnx==1.6.0 ONNXRT_PATH: onnxruntime==1.1.1 COREML_PATH: git+https://github.com/apple/coremltools@3.1 - sklearn.version: '' Python37-170-RT130: python.version: '3.7' ONNX_PATH: onnx==1.7.0 ONNXRT_PATH: onnxruntime==1.3.0 COREML_PATH: git+https://github.com/apple/coremltools@3.1 - sklearn.version: '' maxParallel: 3 steps: - - task: UsePythonVersion@0 - inputs: - versionSpec: '$(python.version)' - architecture: 'x64' - - powershell: Write-Host "##vso[task.prependpath]$env:CONDA\Scripts" displayName: Add conda to PATH - - script: conda create --yes --quiet --name py$(python.version) -c conda-forge python=$(python.version) numpy protobuf + - script: conda create --yes --quiet --name py$(python.version) -c conda-forge python=$(python.version) numpy protobuf scikit-learn scipy cython displayName: Create Anaconda environment - script: | call activate py$(python.version) python -m pip install --upgrade pip numpy echo Test numpy installation... 
&& python -c "import numpy" - python -m pip install scikit-learn - python -m pip install %ONNX_PATH% - python -m pip install humming-bird-ml --no-deps python -m pip install -r requirements.txt - python -m pip install torch==1.8.1+cpu torchvision==0.9.1+cpu torchaudio===0.8.1 -f https://download.pytorch.org/whl/torch_stable.html displayName: 'Install dependencies (1)' - script: | call activate py$(python.version) python -m pip install -r requirements-dev.txt - displayName: 'Install dependencies (2)' + displayName: 'Install dependencies-dev' - script: | call activate py$(python.version) @@ -99,14 +84,10 @@ jobs: - script: | call activate py$(python.version) + python -m pip install %ONNX_PATH% python -m pip install %ONNXRT_PATH% displayName: 'Install onnxruntime' - - script: | - call activate py$(python.version) - python -m pip install scikit-learn$(sklearn.version) - displayName: 'Install scikit-learn' - - script: | call activate py$(python.version) python -m flake8 ./onnxmltools @@ -118,8 +99,67 @@ jobs: export PYTHONPATH=. python -c "import onnxconverter_common;print(onnxconverter_common.__version__)" python -c "import onnxruntime;print(onnxruntime.__version__)" - python -m pytest tests --doctest-modules --junitxml=junit/test-results.xml - displayName: 'pytest - onnxmltools' + displayName: 'version' + + - script: | + call activate py$(python.version) + export PYTHONPATH=. + python -m pytest tests/baseline --durations=0 + displayName: 'pytest baseline' + + - script: | + call activate py$(python.version) + export PYTHONPATH=. + python -m pytest tests/catboost --durations=0 + displayName: 'pytest catboost' + + - script: | + call activate py$(python.version) + export PYTHONPATH=. + python -m pytest tests/coreml --durations=0 + displayName: 'pytest coreml' + + - script: | + call activate py$(python.version) + export PYTHONPATH=. + python -m pytest tests/lightgbm --durations=0 + displayName: 'pytest lightgbm' + + - script: | + call activate py$(python.version) + export PYTHONPATH=. + python -m pytest tests/sparkml --durations=0 + displayName: 'pytest sparkml' + + - script: | + call activate py$(python.version) + export PYTHONPATH=. + python -m pytest tests/svmlib --durations=0 + displayName: 'pytest svmlib' + + - script: | + call activate py$(python.version) + export PYTHONPATH=. + python -m pytest tests/utils --durations=0 + displayName: 'pytest utils' + + - script: | + call activate py$(python.version) + export PYTHONPATH=. + python -m pytest tests/xgboost --durations=0 + displayName: 'pytest xgboost' + + - script: | + call activate py$(python.version) + python -m pip install torch==1.8.1+cpu torchvision==0.9.1+cpu torchaudio===0.8.1 -f https://download.pytorch.org/whl/torch_stable.html + python -m pip install hummingbird-ml --no-deps + displayName: 'Install hummingbird-ml' + + - script: | + call activate py$(python.version) + export PYTHONPATH=. + python -m pytest tests/hummingbirdml --durations=0 + displayName: 'pytest hummingbirdml' - task: PublishTestResults@2 inputs: diff --git a/onnxmltools/convert/lightgbm/_parse.py b/onnxmltools/convert/lightgbm/_parse.py index fd68e6f1..011f6d81 100644 --- a/onnxmltools/convert/lightgbm/_parse.py +++ b/onnxmltools/convert/lightgbm/_parse.py @@ -81,17 +81,20 @@ def _get_lightgbm_operator_name(model): return lightgbm_operator_name_map[model_type] -def _parse_lightgbm_simple_model(scope, model, inputs): +def _parse_lightgbm_simple_model(scope, model, inputs, split=None): ''' This function handles all non-pipeline models. 
     :param scope: Scope object
     :param model: A lightgbm object
     :param inputs: A list of variables
+    :param split: split TreeEnsembleRegressor into multiple nodes to reduce
+        discrepancies
     :return: A list of output variables which will be passed to next stage
     '''
     operator_name = _get_lightgbm_operator_name(model)
     this_operator = scope.declare_local_operator(operator_name, model)
+    this_operator.split = split
     this_operator.inputs = inputs

     if operator_name == 'LgbmClassifier':
@@ -151,7 +154,7 @@ def _parse_sklearn_classifier(scope, model, inputs, zipmap=True):
     return this_operator.outputs


-def _parse_lightgbm(scope, model, inputs, zipmap=True):
+def _parse_lightgbm(scope, model, inputs, zipmap=True, split=None):
     '''
     This is a delegate function. It doesn't nothing but invoke the correct
     parsing function according to the input model's type.
@@ -159,6 +162,8 @@ def _parse_lightgbm(scope, model, inputs, zipmap=True):
     :param model: A lightgbm object
     :param inputs: A list of variables
     :param zipmap: add operator ZipMap after operator TreeEnsembleClassifier
+    :param split: split TreeEnsembleRegressor into multiple nodes to reduce
+        discrepancies
     :return: The output variables produced by the input model
     '''
     if isinstance(model, LGBMClassifier):
@@ -166,12 +171,12 @@ def _parse_lightgbm(scope, model, inputs, zipmap=True):
     if (isinstance(model, WrappedBooster) and
             model.operator_name == 'LgbmClassifier'):
         return _parse_sklearn_classifier(scope, model, inputs, zipmap=zipmap)
-    return _parse_lightgbm_simple_model(scope, model, inputs)
+    return _parse_lightgbm_simple_model(scope, model, inputs, split=split)


 def parse_lightgbm(model, initial_types=None, target_opset=None,
                    custom_conversion_functions=None, custom_shape_calculators=None,
-                   zipmap=True):
+                   zipmap=True, split=None):
     raw_model_container = LightGbmModelContainer(model)
     topology = Topology(raw_model_container, default_batch_size='None',
                         initial_types=initial_types, target_opset=target_opset,
@@ -186,7 +191,7 @@ def parse_lightgbm(model, initial_types=None, target_opset=None,
     for variable in inputs:
         raw_model_container.add_input(variable)

-    outputs = _parse_lightgbm(scope, model, inputs, zipmap=zipmap)
+    outputs = _parse_lightgbm(scope, model, inputs, zipmap=zipmap, split=split)

     for variable in outputs:
         raw_model_container.add_output(variable)
diff --git a/onnxmltools/convert/lightgbm/convert.py b/onnxmltools/convert/lightgbm/convert.py
index a5cfc893..e8ccf3e2 100644
--- a/onnxmltools/convert/lightgbm/convert.py
+++ b/onnxmltools/convert/lightgbm/convert.py
@@ -14,7 +14,8 @@
 def convert(model, name=None, initial_types=None, doc_string='',
             target_opset=None, targeted_onnx=onnx.__version__, custom_conversion_functions=None,
-            custom_shape_calculators=None, without_onnx_ml=False, zipmap=True):
+            custom_shape_calculators=None, without_onnx_ml=False, zipmap=True,
+            split=None):
     '''
     This function produces an equivalent ONNX model of the given lightgbm model.
     The supported lightgbm modules are listed below.
@@ -34,6 +35,16 @@ def convert(model, name=None, initial_types=None, doc_string='', target_opset=No
     :param custom_shape_calculators: a dictionary for specifying the user customized shape calculator
     :param without_onnx_ml: whether to generate a model composed by ONNX operators only, or to allow the converter
-    :param zipmap: remove operator ZipMap from the ONNX graph
-    to use ONNX-ML operators as well.
+        to use ONNX-ML operators as well.
+    :param zipmap: remove operator ZipMap from the ONNX graph
+    :param split: this parameter is useful to reduce the discrepancies observed
+        with large regression forests (number of trees > 100).
+        lightgbm does all its computations with doubles whereas ONNX uses floats.
+        Instead of creating one single TreeEnsembleRegressor node, the converter
+        splits it into multiple TreeEnsembleRegressor nodes, casts every output
+        to double and sums the results. The final graph is slower but keeps the
+        discrepancies constant (they are proportional to the number of trees in
+        one TreeEnsembleRegressor node). Parameter *split* is the number of trees
+        per node. The same could be done for TreeEnsembleClassifier, but the
+        normalization of the probabilities already reduces the discrepancies.
     :return: An ONNX model (type: ModelProto) which is equivalent to the input lightgbm model
     '''
@@ -42,8 +53,8 @@ def convert(model, name=None, initial_types=None, doc_string='', target_opset=No
                       'onnxmltools.convert.lightgbm.convert for details')
     if without_onnx_ml and not hummingbird_installed():
         raise RuntimeError(
-            'Hummingbird is not installed. Please install hummingbird to use this feature: pip install hummingbird-ml'
-        )
+            'Hummingbird is not installed. Please install hummingbird to use this feature: '
+            'pip install hummingbird-ml')
     if isinstance(model, lightgbm.Booster):
         model = WrappedBooster(model)
     if name is None:
@@ -51,11 +62,14 @@ def convert(model, name=None, initial_types=None, doc_string='', target_opset=No
     target_opset = target_opset if target_opset else get_maximum_opset_supported()
     topology = parse_lightgbm(model, initial_types, target_opset, custom_conversion_functions,
-                              custom_shape_calculators, zipmap=zipmap)
+                              custom_shape_calculators, zipmap=zipmap, split=split)
     topology.compile()
     onnx_ml_model = convert_topology(topology, name, doc_string, target_opset, targeted_onnx)
     if without_onnx_ml:
+        if zipmap:
+            raise NotImplementedError(
+                "Conversion with zipmap operator is not implemented with hummingbird-ml.")
         from hummingbird.ml import convert, constants
         extra_config = {}
         # extra_config[constants.ONNX_INITIAL_TYPES] = initial_types
diff --git a/onnxmltools/convert/lightgbm/operator_converters/LightGbm.py b/onnxmltools/convert/lightgbm/operator_converters/LightGbm.py
index 03150c82..423de707 100644
--- a/onnxmltools/convert/lightgbm/operator_converters/LightGbm.py
+++ b/onnxmltools/convert/lightgbm/operator_converters/LightGbm.py
@@ -6,8 +6,9 @@
 import ctypes
 import json
 import numpy as np
+from onnx import TensorProto
 from ...common._apply_operation import (
-    apply_div, apply_reshape, apply_sub, apply_cast, apply_identity, apply_clip)
+    apply_div, apply_reshape, apply_sub, apply_cast, apply_identity)
 from ...common._registration import register_converter
 from ...common.tree_ensemble import get_default_tree_classifier_attribute_pairs
 from ....proto import onnx_proto
@@ -364,6 +365,64 @@ def hook(self, obj):
     return ret, info


+def _split_tree_ensemble_atts(attrs, split):
+    """
+    Splits the attributes of a TreeEnsembleRegressor into multiple nodes
+    so that the summation over the trees is done in double instead of float.
+    """
+    trees_id = list(sorted(set(attrs['nodes_treeids'])))
+    results = []
+    index = 0
+    while index < len(trees_id):
+        index2 = min(index + split, len(trees_id))
+        subset = set(trees_id[index: index2])
+
+        indices_node = []
+        indices_target = []
+        for j, v in enumerate(attrs['nodes_treeids']):
+            if v in subset:
+                indices_node.append(j)
+        for j, v in enumerate(attrs['target_treeids']):
+            if v in subset:
+                indices_target.append(j)
+
+        if (len(indices_node) >= len(attrs['nodes_treeids']) or
+                len(indices_target) >= len(attrs['target_treeids'])):
+            raise RuntimeError(  # pragma: no cover
+                "Initial attributes are not consistent."
+                "\nindex=%r index2=%r subset=%r"
+                "\nnodes_treeids=%r\ntarget_treeids=%r"
+                "\nindices_node=%r\nindices_target=%r" % (
+                    index, index2, subset,
+                    attrs['nodes_treeids'], attrs['target_treeids'],
+                    indices_node, indices_target))
+
+        ats = {}
+        for name, att in attrs.items():
+            if name == 'nodes_treeids':
+                new_att = [att[i] for i in indices_node]
+                new_att = [i - att[0] for i in new_att]
+            elif name == 'target_treeids':
+                new_att = [att[i] for i in indices_target]
+                new_att = [i - att[0] for i in new_att]
+            elif name.startswith("nodes_"):
+                new_att = [att[i] for i in indices_node]
+                assert len(new_att) == len(indices_node)
+            elif name.startswith("target_"):
+                new_att = [att[i] for i in indices_target]
+                assert len(new_att) == len(indices_target)
+            elif name == 'name':
+                new_att = "%s%d" % (att, len(results))
+            else:
+                new_att = att
+            ats[name] = new_att
+
+        results.append(ats)
+        index = index2
+
+    return results
+
+
 def convert_lightgbm(scope, operator, container):
     """
     Converters for *lightgbm*.
@@ -541,9 +600,32 @@ def convert_lightgbm(scope, operator, container):
             # and TreeEnsembleClassifier have different ONNX attributes
             attrs['target' + k[5:]] = copy.deepcopy(attrs[k])
             del attrs[k]
-        container.add_node(
-            'TreeEnsembleRegressor', operator.input_full_names,
-            output_name, op_domain='ai.onnx.ml', **attrs)
+
+        split = getattr(operator, 'split', None)
+        if split in (None, -1):
+            container.add_node(
+                'TreeEnsembleRegressor', operator.input_full_names,
+                output_name, op_domain='ai.onnx.ml', **attrs)
+        else:
+            tree_attrs = _split_tree_ensemble_atts(attrs, split)
+            tree_nodes = []
+            for i, ats in enumerate(tree_attrs):
+                tree_name = scope.get_unique_variable_name('tree%d' % i)
+                container.add_node(
+                    'TreeEnsembleRegressor', operator.input_full_names,
+                    tree_name, op_domain='ai.onnx.ml', **ats)
+                cast_name = scope.get_unique_variable_name('dtree%d' % i)
+                container.add_node(
+                    'Cast', tree_name, cast_name, to=TensorProto.DOUBLE,  # pylint: disable=E1101
+                    name=scope.get_unique_operator_name("dtree%d" % i))
+                tree_nodes.append(cast_name)
+            cast_name = scope.get_unique_variable_name('ftrees')
+            container.add_node(
+                'Sum', tree_nodes, cast_name,
+                name=scope.get_unique_operator_name("sumtree%d" % len(tree_nodes)))
+            container.add_node(
+                'Cast', cast_name, output_name, to=TensorProto.FLOAT,  # pylint: disable=E1101
+                name=scope.get_unique_operator_name("ftrees%d" % len(tree_nodes)))

     if gbm_model.boosting_type == 'rf':
         denominator_name = scope.get_unique_variable_name('denominator')
@@ -722,12 +804,16 @@ def convert_lgbm_zipmap(scope, operator, container):
             operator.outputs[1].full_name,
             op_domain='ai.onnx.ml', **zipmap_attrs)
     else:
-        # This should be apply_identity but optimization fails in
-        # onnxconverter-common when trying to remove identity nodes.
-        apply_clip(scope, operator.inputs[1].full_name,
-                   operator.outputs[1].full_name, container,
-                   min=np.array([0], dtype=np.float32),
-                   max=np.array([1], dtype=np.float32))
+        # This should be apply_identity but optimization fails in
+        # onnxconverter-common when trying to remove identity nodes
+        # if node identity is used.
+        one = scope.get_unique_variable_name('one')
+
+        container.add_initializer(
+            one, onnx_proto.TensorProto.FLOAT, [], [1])
+        container.add_node(
+            'Mul', [operator.inputs[1].full_name, one],
+            operator.outputs[1].full_name)


 register_converter('LgbmClassifier', convert_lightgbm)
diff --git a/onnxmltools/convert/main.py b/onnxmltools/convert/main.py
index 7f19d137..50ec11ea 100644
--- a/onnxmltools/convert/main.py
+++ b/onnxmltools/convert/main.py
@@ -126,7 +126,8 @@ def convert_catboost(model, name=None, initial_types=None, doc_string='', target
 def convert_lightgbm(model, name=None, initial_types=None, doc_string='', target_opset=None,
                      targeted_onnx=None, custom_conversion_functions=None,
-                     custom_shape_calculators=None, without_onnx_ml=False, zipmap=True):
+                     custom_shape_calculators=None, without_onnx_ml=False, zipmap=True,
+                     split=None):
     if targeted_onnx is not None:
         warnings.warn("targeted_onnx is deprecated. Use target_opset.", DeprecationWarning)
     if not utils.lightgbm_installed():
@@ -135,7 +136,7 @@ def convert_lightgbm(model, name=None, initial_types=None, doc_string='', target
     from .lightgbm.convert import convert
     return convert(model, name, initial_types, doc_string, target_opset, targeted_onnx,
                    custom_conversion_functions, custom_shape_calculators, without_onnx_ml,
-                   zipmap=zipmap)
+                   zipmap=zipmap, split=split)


 def convert_sklearn(model, name=None, initial_types=None, doc_string='', target_opset=None,
diff --git a/onnxmltools/utils/utils_backend_onnxruntime.py b/onnxmltools/utils/utils_backend_onnxruntime.py
index d8edb1a0..74622069 100644
--- a/onnxmltools/utils/utils_backend_onnxruntime.py
+++ b/onnxmltools/utils/utils_backend_onnxruntime.py
@@ -161,7 +161,7 @@ def to_array(vv):
                 smodel = "\nJSON ONNX\n" + str(model)
             else:
                 smodel = ""
-            raise OnnxRuntimeAssertionError("Model '{0}' has discrepencies.\n{1}: {2}{3}".format(onx, type(e), e, smodel))
+            raise OnnxRuntimeAssertionError("Model '{0}' has discrepancies.\n{1}: {2}{3}".format(onx, type(e), e, smodel))
     return output0
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 8f4d8ab0..897b7ca4 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -1,12 +1,9 @@
--f https://download.pytorch.org/whl/torch_stable.html
 catboost
 codecov
 cython
 dill
 flake8
 flatbuffers
-h2o
-hummingbird-ml
 libsvm
 lightgbm
 mleap
@@ -21,7 +18,5 @@ pytest-cov
 pytest-spark
 scikit-learn
 scipy
-tensorflow
-torch
 wheel
 xgboost
diff --git a/tests/h2o/test_h2o_converters.py b/tests/h2o/test_h2o_converters.py
index b3df8770..dac438f6 100644
--- a/tests/h2o/test_h2o_converters.py
+++ b/tests/h2o/test_h2o_converters.py
@@ -14,6 +14,7 @@
 import h2o
 from h2o import H2OFrame
 from h2o.estimators.gbm import H2OGradientBoostingEstimator
+from h2o.exceptions import H2OConnectionError
 from h2o.estimators.random_forest import H2ORandomForestEstimator
 from onnxmltools.convert import convert_h2o
 from onnxmltools.utils import dump_data_and_model
diff --git a/tests/hummingbirdml/test_LightGbmTreeEnsembleConverters_hummingbird.py b/tests/hummingbirdml/test_LightGbmTreeEnsembleConverters_hummingbird.py
new file mode 100644
index 00000000..42b270f6
--- /dev/null
+++ b/tests/hummingbirdml/test_LightGbmTreeEnsembleConverters_hummingbird.py
@@ -0,0 +1,275 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import
unittest +from distutils.version import StrictVersion + +import lightgbm +import numpy +from numpy.testing import assert_almost_equal +from onnx.defs import onnx_opset_version +from lightgbm import LGBMClassifier, LGBMRegressor +from onnxruntime import InferenceSession +from onnxmltools.convert.common.utils import hummingbird_installed +from onnxmltools.convert.common.data_types import FloatTensorType +from onnxmltools.convert import convert_lightgbm +from onnxmltools.utils import dump_data_and_model +from onnxmltools.utils import dump_binary_classification, dump_multiple_classification +from onnxmltools.utils import dump_single_regression +from onnxmltools.utils.tests_helper import convert_model + +TARGET_OPSET = min(13, onnx_opset_version()) + + +class TestLightGbmTreeEnsembleModelsHummingBird(unittest.TestCase): + + @classmethod + def setUpClass(cls): + print('BEGIN.') + import torch + print(torch.device("cuda:0" if torch.cuda.is_available() else "cpu")) + + @classmethod + def tearDownClass(cls): + print("END.") + + # Tests with ONNX operators only + @unittest.skipIf(not hummingbird_installed(), reason="Hummingbird is not installed") + def test_lightgbm_booster_classifier(self): + X = [[0, 1], [1, 1], [2, 0], [1, 2]] + X = numpy.array(X, dtype=numpy.float32) + y = [0, 1, 0, 1] + data = lightgbm.Dataset(X, label=y) + model = lightgbm.train({'boosting_type': 'gbdt', 'objective': 'binary', + 'n_estimators': 3, 'min_child_samples': 1, 'num_thread': 1}, + data) + model_onnx, prefix = convert_model(model, 'tree-based classifier', + [('input', FloatTensorType([None, 2]))], without_onnx_ml=True, + target_opset=TARGET_OPSET, + zipmap=False) + dump_data_and_model(X, model, model_onnx, + allow_failure="StrictVersion(onnx.__version__) < StrictVersion('1.3.0')", + basename=prefix + "BoosterBin" + model.__class__.__name__) + + @unittest.skipIf(not hummingbird_installed(), reason="Hummingbird is not installed") + def test_lightgbm_booster_classifier_zipmap(self): + X = [[0, 1], [1, 1], [2, 0], [1, 2]] + X = numpy.array(X, dtype=numpy.float32) + y = [0, 1, 0, 1] + data = lightgbm.Dataset(X, label=y) + model = lightgbm.train({'boosting_type': 'gbdt', 'objective': 'binary', + 'n_estimators': 3, 'min_child_samples': 1, 'num_thread': 1}, + data) + model_onnx, prefix = convert_model(model, 'tree-based classifier', + [('input', FloatTensorType([None, 2]))], without_onnx_ml=False, + target_opset=TARGET_OPSET) + assert "zipmap" in str(model_onnx).lower() + with self.assertRaises(NotImplementedError): + convert_model(model, 'tree-based classifier', + [('input', FloatTensorType([None, 2]))], without_onnx_ml=True, + target_opset=TARGET_OPSET) + + model_onnx, prefix = convert_model(model, 'tree-based classifier', + [('input', FloatTensorType([None, 2]))], without_onnx_ml=True, + target_opset=TARGET_OPSET, zipmap=False) + dump_data_and_model(X, model, model_onnx, + allow_failure="StrictVersion(onnx.__version__) < StrictVersion('1.3.0')", + basename=prefix + "BoosterBin" + model.__class__.__name__) + + @unittest.skipIf(not hummingbird_installed(), reason="Hummingbird is not installed") + def test_lightgbm_booster_multi_classifier(self): + X = [[0, 1], [1, 1], [2, 0], [1, 2], [-1, 2], [1, -2]] + X = numpy.array(X, dtype=numpy.float32) + y = [0, 1, 0, 1, 2, 2] + data = lightgbm.Dataset(X, label=y) + model = lightgbm.train({'boosting_type': 'gbdt', 'objective': 'multiclass', + 'n_estimators': 3, 'min_child_samples': 1, 'num_class': 3, 'num_thread': 1}, + data) + model_onnx, prefix = convert_model(model, 'tree-based 
classifier', + [('input', FloatTensorType([None, 2]))], without_onnx_ml=True, + target_opset=TARGET_OPSET, zipmap=False) + dump_data_and_model(X, model, model_onnx, + allow_failure="StrictVersion(onnx.__version__) < StrictVersion('1.3.0')", + basename=prefix + "BoosterBin" + model.__class__.__name__) + sess = InferenceSession(model_onnx.SerializeToString()) + out = sess.get_outputs() + names = [o.name for o in out] + assert names == ['label', 'probabilities'] + + @unittest.skipIf(not hummingbird_installed(), reason="Hummingbird is not installed") + def test_lightgbm_booster_regressor(self): + X = [[0, 1], [1, 1], [2, 0]] + X = numpy.array(X, dtype=numpy.float32) + y = [0, 1, 1.1] + data = lightgbm.Dataset(X, label=y) + model = lightgbm.train({'boosting_type': 'gbdt', 'objective': 'regression', + 'n_estimators': 3, 'min_child_samples': 1, 'max_depth': 1}, + data) + model_onnx, prefix = convert_model(model, 'tree-based binary regressor', + [('input', FloatTensorType([None, 2]))], without_onnx_ml=True, + target_opset=TARGET_OPSET, zipmap=False) + dump_data_and_model(X, model, model_onnx, + allow_failure="StrictVersion(onnx.__version__) < StrictVersion('1.0.0')", + basename=prefix + "BoosterBin" + model.__class__.__name__) + + # Base test implementation comparing ONNXML and ONNX models. + def _test_lgbm(self, X, model, extra_config={}): + # Create ONNX-ML model + onnx_ml_model = convert_model( + model, 'lgbm-onnxml', [("input", FloatTensorType([None, X.shape[1]]))], + target_opset=TARGET_OPSET)[0] + + # Create ONNX model + onnx_model = convert_model( + model, 'lgbm-onnx', [("input", FloatTensorType([None, X.shape[1]]))], without_onnx_ml=True, + target_opset=TARGET_OPSET)[0] + + # Get the predictions for the ONNX-ML model + session = InferenceSession(onnx_ml_model.SerializeToString()) + output_names = [session.get_outputs()[i].name for i in range(len(session.get_outputs()))] + onnx_ml_pred = [[] for i in range(len(output_names))] + inputs = {session.get_inputs()[0].name: X} + pred = session.run(output_names, inputs) + for i in range(len(output_names)): + if output_names[i] == "label": + onnx_ml_pred[1] = pred[i] + else: + onnx_ml_pred[0] = pred[i] + + # Get the predictions for the ONNX model + session = InferenceSession(onnx_model.SerializeToString()) + onnx_pred = [[] for i in range(len(output_names))] + pred = session.run(output_names, inputs) + for i in range(len(output_names)): + if output_names[i] == "label": + onnx_pred[1] = pred[i] + else: + onnx_pred[0] = pred[i] + + return onnx_ml_pred, onnx_pred, output_names + + # Utility function for testing regression models. + def _test_regressor(self, X, model, rtol=1e-06, atol=1e-06, extra_config={}): + onnx_ml_pred, onnx_pred, output_names = self._test_lgbm(X, model, extra_config) + + # Check that predicted values match + numpy.testing.assert_allclose(onnx_ml_pred[0], onnx_pred[0], rtol=rtol, atol=atol) + + # Utility function for testing classification models. + def _test_classifier(self, X, model, rtol=1e-06, atol=1e-06, extra_config={}): + onnx_ml_pred, onnx_pred, output_names = self._test_lgbm(X, model, extra_config) + + numpy.testing.assert_allclose(onnx_ml_pred[1], onnx_pred[1], rtol=rtol, atol=atol) # labels + numpy.testing.assert_allclose( + list(map(lambda x: list(x.values()), onnx_ml_pred[0])), onnx_pred[0], rtol=rtol, atol=atol + ) # probs + + # Regression test with 3 estimators. 
+ @unittest.skipIf(not hummingbird_installed(), reason="Hummingbird is not installed") + def _test_lightgbm_regressor(self): + X = [[0, 1], [1, 1], [2, 0], [4, 0], [2, 3]] + X = numpy.array(X, dtype=numpy.float32) + y = numpy.array([100, -10, 50, 10, 10], dtype=numpy.float32) + model = LGBMRegressor(n_estimators=3, min_child_samples=1, num_thread=1) + model.fit(X, y) + self._test_regressor(X, model) + + # Regression test with 1 estimator. + @unittest.skipIf(not hummingbird_installed(), reason="Hummingbird is not installed") + def _test_lightgbm_regressor1(self): + model = LGBMRegressor(n_estimators=1, min_child_samples=1, num_thread=1) + X = [[0, 1], [1, 1], [2, 0]] + X = numpy.array(X, dtype=numpy.float32) + y = numpy.array([100, -10, 50], dtype=numpy.float32) + model.fit(X, y) + self._test_regressor(X, model) + + # Regression test with 2 estimators. + @unittest.skipIf(not hummingbird_installed(), reason="Hummingbird is not installed") + def _test_lightgbm_regressor2(self): + model = LGBMRegressor(n_estimators=2, max_depth=1, min_child_samples=1, num_thread=1) + X = [[0, 1], [1, 1], [2, 0]] + X = numpy.array(X, dtype=numpy.float32) + y = numpy.array([100, -10, 50], dtype=numpy.float32) + model.fit(X, y) + self._test_regressor(X, model) + + # Regression test with gbdt boosting type. + @unittest.skipIf(not hummingbird_installed(), reason="Hummingbird is not installed") + def _test_lightgbm_booster_regressor(self): + X = [[0, 1], [1, 1], [2, 0]] + X = numpy.array(X, dtype=numpy.float32) + y = [0, 1, 1.1] + data = lightgbm.Dataset(X, label=y) + model = lightgbm.train( + {"boosting_type": "gbdt", "objective": "regression", "n_estimators": 3, + "min_child_samples": 1, "max_depth": 1, 'num_thread': 1}, + data, + ) + self._test_regressor(X, model) + + # Binary classification test with 3 estimators. + @unittest.skipIf(not hummingbird_installed(), reason="Hummingbird is not installed") + def _test_lightgbm_classifier(self): + model = LGBMClassifier(n_estimators=3, min_child_samples=1, num_thread=1) + X = [[0, 1], [1, 1], [2, 0]] + X = numpy.array(X, dtype=numpy.float32) + y = [0, 1, 0] + model.fit(X, y) + self._test_classifier(X, model) + + # Binary classification test with 3 estimators zipmap. + @unittest.skipIf(not hummingbird_installed(), reason="Hummingbird is not installed") + def _test_lightgbm_classifier_zipmap(self): + X = [[0, 1], [1, 1], [2, 0], [1, 2]] + X = numpy.array(X, dtype=numpy.float32) + y = [0, 1, 0, 1] + model = LGBMClassifier(n_estimators=3, min_child_samples=1, num_thread=1) + model.fit(X, y) + self._test_classifier(X, model) + + # Binary classification test with 3 estimators and selecting boosting type. + @unittest.skipIf(not hummingbird_installed(), reason="Hummingbird is not installed") + def _test_lightgbm_booster_classifier(self): + X = [[0, 1], [1, 1], [2, 0], [1, 2]] + X = numpy.array(X, dtype=numpy.float32) + y = [0, 1, 0, 1] + data = lightgbm.Dataset(X, label=y) + model = lightgbm.train({"boosting_type": "gbdt", "objective": "binary", "n_estimators": 3, "min_child_samples": 1, 'num_thread': 1}, data) + self._test_classifier(X, model) + + # Binary classification test with 3 estimators and selecting boosting type zipmap. 
+ @unittest.skipIf(not hummingbird_installed(), reason="Hummingbird is not installed") + def _test_lightgbm_booster_classifier_zipmap(self): + X = [[0, 1], [1, 1], [2, 0], [1, 2]] + X = numpy.array(X, dtype=numpy.float32) + y = [0, 1, 0, 1] + data = lightgbm.Dataset(X, label=y) + model = lightgbm.train({"boosting_type": "gbdt", "objective": "binary", "n_estimators": 3, "min_child_samples": 1, 'num_thread': 1}, data) + self._test_classifier(X, model) + + # Multiclass classification test with 3 estimators. + @unittest.skipIf(not hummingbird_installed(), reason="Hummingbird is not installed") + def _test_lightgbm_classifier_multi(self): + model = LGBMClassifier(n_estimators=3, min_child_samples=1, num_thread=1) + X = [[0, 1], [1, 1], [2, 0], [0.5, 0.5], [1.1, 1.1], [2.1, 0.1]] + X = numpy.array(X, dtype=numpy.float32) + y = [0, 1, 2, 1, 1, 2] + model.fit(X, y) + self._test_classifier(X, model) + + # Multiclass classification test with 3 estimators and selecting boosting type. + @unittest.skipIf(not hummingbird_installed(), reason="Hummingbird is not installed") + def _test_lightgbm_booster_multi_classifier(self): + X = [[0, 1], [1, 1], [2, 0], [1, 2], [-1, 2], [1, -2]] + X = numpy.array(X, dtype=numpy.float32) + y = [0, 1, 0, 1, 2, 2] + data = lightgbm.Dataset(X, label=y) + model = lightgbm.train( + {"boosting_type": "gbdt", "objective": "multiclass", "n_estimators": 3, "min_child_samples": 1, "num_class": 3, 'num_thread': 1}, + data, + ) + self._test_classifier(X, model) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/lightgbm/test_LightGbmTreeEnsembleConverters.py b/tests/lightgbm/test_LightGbmTreeEnsembleConverters.py index bcbfb191..0eba0cf4 100644 --- a/tests/lightgbm/test_LightGbmTreeEnsembleConverters.py +++ b/tests/lightgbm/test_LightGbmTreeEnsembleConverters.py @@ -9,7 +9,6 @@ from onnx.defs import onnx_opset_version from lightgbm import LGBMClassifier, LGBMRegressor import onnxruntime -from onnxmltools.convert.common.utils import hummingbird_installed from onnxmltools.convert.common.data_types import FloatTensorType from onnxmltools.convert import convert_lightgbm from onnxmltools.utils import dump_data_and_model @@ -23,7 +22,7 @@ class TestLightGbmTreeEnsembleModels(unittest.TestCase): def test_lightgbm_classifier(self): - model = LGBMClassifier(n_estimators=3, min_child_samples=1) + model = LGBMClassifier(n_estimators=3, min_child_samples=1, num_thread=1) dump_binary_classification(model, allow_failure="StrictVersion(onnx.__version__) < StrictVersion('1.3.0')") dump_multiple_classification(model, allow_failure="StrictVersion(onnx.__version__) < StrictVersion('1.3.0')") @@ -31,7 +30,7 @@ def test_lightgbm_classifier_zipmap(self): X = [[0, 1], [1, 1], [2, 0], [1, 2]] X = numpy.array(X, dtype=numpy.float32) y = [0, 1, 0, 1] - model = LGBMClassifier(n_estimators=3, min_child_samples=1) + model = LGBMClassifier(n_estimators=3, min_child_samples=1, num_thread=1) model.fit(X, y) onx = convert_model( model, 'dummy', input_types=[('X', FloatTensorType([None, X.shape[1]]))], @@ -42,7 +41,7 @@ def test_lightgbm_classifier_nozipmap(self): X = [[0, 1], [1, 1], [2, 0], [1, 2], [1, 5], [6, 2]] X = numpy.array(X, dtype=numpy.float32) y = [0, 1, 0, 1, 1, 0] - model = LGBMClassifier(n_estimators=3, min_child_samples=1, max_depth=2) + model = LGBMClassifier(n_estimators=3, min_child_samples=1, max_depth=2, num_thread=1) model.fit(X, y) onx = convert_model( model, 'dummy', input_types=[('X', FloatTensorType([None, X.shape[1]]))], @@ -64,7 +63,7 @@ def 
test_lightgbm_classifier_nozipmap2(self): X = [[0, 1], [1, 1], [2, 0], [1, 2], [1, 5], [6, 2]] X = numpy.array(X, dtype=numpy.float32) y = [0, 1, 0, 1, 1, 0] - model = LGBMClassifier(n_estimators=3, min_child_samples=1, max_depth=2) + model = LGBMClassifier(n_estimators=3, min_child_samples=1, max_depth=2, num_thread=1) model.fit(X, y) onx = convert_lightgbm( model, 'dummy', initial_types=[('X', FloatTensorType([None, X.shape[1]]))], @@ -83,15 +82,15 @@ def test_lightgbm_classifier_nozipmap2(self): assert_almost_equal(exp[1], got[1]) def test_lightgbm_regressor(self): - model = LGBMRegressor(n_estimators=3, min_child_samples=1) + model = LGBMRegressor(n_estimators=3, min_child_samples=1, num_thread=1) dump_single_regression(model) def test_lightgbm_regressor1(self): - model = LGBMRegressor(n_estimators=1, min_child_samples=1) + model = LGBMRegressor(n_estimators=1, min_child_samples=1, num_thread=1) dump_single_regression(model, suffix="1") def test_lightgbm_regressor2(self): - model = LGBMRegressor(n_estimators=2, max_depth=1, min_child_samples=1) + model = LGBMRegressor(n_estimators=2, max_depth=1, min_child_samples=1, num_thread=1) dump_single_regression(model, suffix="2") def test_lightgbm_booster_classifier(self): @@ -100,7 +99,7 @@ def test_lightgbm_booster_classifier(self): y = [0, 1, 0, 1] data = lightgbm.Dataset(X, label=y) model = lightgbm.train({'boosting_type': 'gbdt', 'objective': 'binary', - 'n_estimators': 3, 'min_child_samples': 1}, + 'n_estimators': 3, 'min_child_samples': 1, 'num_thread': 1}, data) model_onnx, prefix = convert_model(model, 'tree-based classifier', [('input', FloatTensorType([None, 2]))], @@ -115,7 +114,7 @@ def test_lightgbm_booster_classifier_nozipmap(self): y = [0, 1, 0, 1] data = lightgbm.Dataset(X, label=y) model = lightgbm.train({'boosting_type': 'gbdt', 'objective': 'binary', - 'n_estimators': 3, 'min_child_samples': 1}, + 'n_estimators': 3, 'min_child_samples': 1, 'num_thread': 1}, data) model_onnx, prefix = convert_model(model, 'tree-based classifier', [('input', FloatTensorType([None, 2]))], @@ -131,7 +130,7 @@ def test_lightgbm_booster_classifier_zipmap(self): y = [0, 1, 0, 1] data = lightgbm.Dataset(X, label=y) model = lightgbm.train({'boosting_type': 'gbdt', 'objective': 'binary', - 'n_estimators': 3, 'min_child_samples': 1}, + 'n_estimators': 3, 'min_child_samples': 1, 'num_thread': 1}, data) model_onnx, prefix = convert_model(model, 'tree-based classifier', [('input', FloatTensorType([None, 2]))], @@ -147,7 +146,7 @@ def test_lightgbm_booster_multi_classifier(self): y = [0, 1, 0, 1, 2, 2] data = lightgbm.Dataset(X, label=y) model = lightgbm.train({'boosting_type': 'gbdt', 'objective': 'multiclass', - 'n_estimators': 3, 'min_child_samples': 1, 'num_class': 3}, + 'n_estimators': 3, 'min_child_samples': 1, 'num_class': 3, 'num_thread': 1}, data) model_onnx, prefix = convert_model(model, 'tree-based classifier', [('input', FloatTensorType([None, 2]))]) @@ -170,248 +169,13 @@ def test_lightgbm_booster_regressor(self): y = [0, 1, 1.1] data = lightgbm.Dataset(X, label=y) model = lightgbm.train({'boosting_type': 'gbdt', 'objective': 'regression', - 'n_estimators': 3, 'min_child_samples': 1, 'max_depth': 1}, + 'n_estimators': 3, 'min_child_samples': 1, 'max_depth': 1, 'num_thread': 1}, data) model_onnx, prefix = convert_model(model, 'tree-based binary classifier', [('input', FloatTensorType([None, 2]))]) dump_data_and_model(X, model, model_onnx, basename=prefix + "BoosterBin" + model.__class__.__name__) - # Tests with ONNX operators only - 
@unittest.skipIf(not hummingbird_installed(), reason="Hummingbird is not installed") - def test_lightgbm_booster_classifier(self): - X = [[0, 1], [1, 1], [2, 0], [1, 2]] - X = numpy.array(X, dtype=numpy.float32) - y = [0, 1, 0, 1] - data = lightgbm.Dataset(X, label=y) - model = lightgbm.train({'boosting_type': 'gbdt', 'objective': 'binary', - 'n_estimators': 3, 'min_child_samples': 1}, - data) - model_onnx, prefix = convert_model(model, 'tree-based classifier', - [('input', FloatTensorType([None, 2]))], without_onnx_ml=True) - dump_data_and_model(X, model, model_onnx, - allow_failure="StrictVersion(onnx.__version__) < StrictVersion('1.3.0')", - basename=prefix + "BoosterBin" + model.__class__.__name__) - - @unittest.skipIf(not hummingbird_installed(), reason="Hummingbird is not installed") - def test_lightgbm_booster_classifier_zipmap(self): - X = [[0, 1], [1, 1], [2, 0], [1, 2]] - X = numpy.array(X, dtype=numpy.float32) - y = [0, 1, 0, 1] - data = lightgbm.Dataset(X, label=y) - model = lightgbm.train({'boosting_type': 'gbdt', 'objective': 'binary', - 'n_estimators': 3, 'min_child_samples': 1}, - data) - model_onnx, prefix = convert_model(model, 'tree-based classifier', - [('input', FloatTensorType([None, 2]))], without_onnx_ml=True) - assert "zipmap" in str(model_onnx).lower() - dump_data_and_model(X, model, model_onnx, - allow_failure="StrictVersion(onnx.__version__) < StrictVersion('1.3.0')", - basename=prefix + "BoosterBin" + model.__class__.__name__) - - @unittest.skipIf(not hummingbird_installed(), reason="Hummingbird is not installed") - def test_lightgbm_booster_multi_classifier(self): - X = [[0, 1], [1, 1], [2, 0], [1, 2], [-1, 2], [1, -2]] - X = numpy.array(X, dtype=numpy.float32) - y = [0, 1, 0, 1, 2, 2] - data = lightgbm.Dataset(X, label=y) - model = lightgbm.train({'boosting_type': 'gbdt', 'objective': 'multiclass', - 'n_estimators': 3, 'min_child_samples': 1, 'num_class': 3}, - data) - model_onnx, prefix = convert_model(model, 'tree-based classifier', - [('input', FloatTensorType([None, 2]))], without_onnx_ml=True) - dump_data_and_model(X, model, model_onnx, - allow_failure="StrictVersion(onnx.__version__) < StrictVersion('1.3.0')", - basename=prefix + "BoosterBin" + model.__class__.__name__) - try: - from onnxruntime import InferenceSession - except ImportError: - # onnxruntime not installed (python 2.7) - return - sess = InferenceSession(model_onnx.SerializeToString()) - out = sess.get_outputs() - names = [o.name for o in out] - assert names == ['label', 'probabilities'] - - @unittest.skipIf(not hummingbird_installed(), reason="Hummingbird is not installed") - def test_lightgbm_booster_regressor(self): - X = [[0, 1], [1, 1], [2, 0]] - X = numpy.array(X, dtype=numpy.float32) - y = [0, 1, 1.1] - data = lightgbm.Dataset(X, label=y) - model = lightgbm.train({'boosting_type': 'gbdt', 'objective': 'regression', - 'n_estimators': 3, 'min_child_samples': 1, 'max_depth': 1}, - data) - model_onnx, prefix = convert_model(model, 'tree-based binary classifier', - [('input', FloatTensorType([None, 2]))], without_onnx_ml=True) - dump_data_and_model(X, model, model_onnx, - allow_failure="StrictVersion(onnx.__version__) < StrictVersion('1.0.0')", - basename=prefix + "BoosterBin" + model.__class__.__name__) - - # Base test implementation comparing ONNXML and ONNX models. 
- def _test_lgbm(self, X, model, extra_config={}): - # Create ONNX-ML model - onnx_ml_model = convert_model( - model, 'lgbm-onnxml', [("input", FloatTensorType([X.shape[0], X.shape[1]]))], - target_opset=TARGET_OPSET)[0] - - # Create ONNX model - onnx_model = convert_model( - model, 'lgbm-onnx', [("input", FloatTensorType([X.shape[0], X.shape[1]]))], without_onnx_ml=True, - target_opset=TARGET_OPSET)[0] - - try: - from onnxruntime import InferenceSession - except ImportError: - # onnxruntime not installed (python 2.7) - return - - # Get the predictions for the ONNX-ML model - session = InferenceSession(onnx_ml_model.SerializeToString()) - output_names = [session.get_outputs()[i].name for i in range(len(session.get_outputs()))] - onnx_ml_pred = [[] for i in range(len(output_names))] - inputs = {session.get_inputs()[0].name: X} - pred = session.run(output_names, inputs) - for i in range(len(output_names)): - if output_names[i] == "label": - onnx_ml_pred[1] = pred[i] - else: - onnx_ml_pred[0] = pred[i] - - # Get the predictions for the ONNX model - session = InferenceSession(onnx_model.SerializeToString()) - onnx_pred = [[] for i in range(len(output_names))] - pred = session.run(output_names, inputs) - for i in range(len(output_names)): - if output_names[i] == "label": - onnx_pred[1] = pred[i] - else: - onnx_pred[0] = pred[i] - - return onnx_ml_pred, onnx_pred, output_names - - # Utility function for testing regression models. - def _test_regressor(self, X, model, rtol=1e-06, atol=1e-06, extra_config={}): - onnx_ml_pred, onnx_pred, output_names = self._test_lgbm(X, model, extra_config) - - # Check that predicted values match - numpy.testing.assert_allclose(onnx_ml_pred[0], onnx_pred[0], rtol=rtol, atol=atol) - - # Utility function for testing classification models. - def _test_classifier(self, X, model, rtol=1e-06, atol=1e-06, extra_config={}): - onnx_ml_pred, onnx_pred, output_names = self._test_lgbm(X, model, extra_config) - - numpy.testing.assert_allclose(onnx_ml_pred[1], onnx_pred[1], rtol=rtol, atol=atol) # labels - numpy.testing.assert_allclose( - list(map(lambda x: list(x.values()), onnx_ml_pred[0])), onnx_pred[0], rtol=rtol, atol=atol - ) # probs - - # Regression test with 3 estimators. - @unittest.skipIf(not hummingbird_installed(), reason="Hummingbird is not installed") - def test_lightgbm_regressor(self): - X = [[0, 1], [1, 1], [2, 0]] - X = numpy.array(X, dtype=numpy.float32) - y = numpy.array([100, -10, 50], dtype=numpy.float32) - model = LGBMRegressor(n_estimators=3, min_child_samples=1) - model.fit(X, y) - self._test_regressor(X, model) - - # Regression test with 1 estimator. - @unittest.skipIf(not hummingbird_installed(), reason="Hummingbird is not installed") - def test_lightgbm_regressor1(self): - model = LGBMRegressor(n_estimators=1, min_child_samples=1) - X = [[0, 1], [1, 1], [2, 0]] - X = numpy.array(X, dtype=numpy.float32) - y = numpy.array([100, -10, 50], dtype=numpy.float32) - model.fit(X, y) - self._test_regressor(X, model) - - # Regression test with 2 estimators. - @unittest.skipIf(not hummingbird_installed(), reason="Hummingbird is not installed") - def test_lightgbm_regressor2(self): - model = LGBMRegressor(n_estimators=2, max_depth=1, min_child_samples=1) - X = [[0, 1], [1, 1], [2, 0]] - X = numpy.array(X, dtype=numpy.float32) - y = numpy.array([100, -10, 50], dtype=numpy.float32) - model.fit(X, y) - self._test_regressor(X, model) - - # Regression test with gbdt boosting type. 
- @unittest.skipIf(not hummingbird_installed(), reason="Hummingbird is not installed") - def test_lightgbm_booster_regressor(self): - X = [[0, 1], [1, 1], [2, 0]] - X = numpy.array(X, dtype=numpy.float32) - y = [0, 1, 1.1] - data = lightgbm.Dataset(X, label=y) - model = lightgbm.train( - {"boosting_type": "gbdt", "objective": "regression", "n_estimators": 3, "min_child_samples": 1, "max_depth": 1}, - data, - ) - self._test_regressor(X, model) - - # Binary classification test with 3 estimators. - @unittest.skipIf(not hummingbird_installed(), reason="Hummingbird is not installed") - def test_lightgbm_classifier(self): - model = LGBMClassifier(n_estimators=3, min_child_samples=1) - X = [[0, 1], [1, 1], [2, 0]] - X = numpy.array(X, dtype=numpy.float32) - y = [0, 1, 0] - model.fit(X, y) - self._test_classifier(X, model) - - # Binary classification test with 3 estimators zipmap. - @unittest.skipIf(not hummingbird_installed(), reason="Hummingbird is not installed") - def test_lightgbm_classifier_zipmap(self): - X = [[0, 1], [1, 1], [2, 0], [1, 2]] - X = numpy.array(X, dtype=numpy.float32) - y = [0, 1, 0, 1] - model = LGBMClassifier(n_estimators=3, min_child_samples=1) - model.fit(X, y) - self._test_classifier(X, model) - - # Binary classification test with 3 estimators and selecting boosting type. - @unittest.skipIf(not hummingbird_installed(), reason="Hummingbird is not installed") - def test_lightgbm_booster_classifier(self): - X = [[0, 1], [1, 1], [2, 0], [1, 2]] - X = numpy.array(X, dtype=numpy.float32) - y = [0, 1, 0, 1] - data = lightgbm.Dataset(X, label=y) - model = lightgbm.train({"boosting_type": "gbdt", "objective": "binary", "n_estimators": 3, "min_child_samples": 1}, data) - self._test_classifier(X, model) - - # Binary classification test with 3 estimators and selecting boosting type zipmap. - @unittest.skipIf(not hummingbird_installed(), reason="Hummingbird is not installed") - def test_lightgbm_booster_classifier_zipmap(self): - X = [[0, 1], [1, 1], [2, 0], [1, 2]] - X = numpy.array(X, dtype=numpy.float32) - y = [0, 1, 0, 1] - data = lightgbm.Dataset(X, label=y) - model = lightgbm.train({"boosting_type": "gbdt", "objective": "binary", "n_estimators": 3, "min_child_samples": 1}, data) - self._test_classifier(X, model) - - # Multiclass classification test with 3 estimators. - @unittest.skipIf(not hummingbird_installed(), reason="Hummingbird is not installed") - def test_lightgbm_classifier_multi(self): - model = LGBMClassifier(n_estimators=3, min_child_samples=1) - X = [[0, 1], [1, 1], [2, 0], [0.5, 0.5], [1.1, 1.1], [2.1, 0.1]] - X = numpy.array(X, dtype=numpy.float32) - y = [0, 1, 2, 1, 1, 2] - model.fit(X, y) - self._test_classifier(X, model) - - # Multiclass classification test with 3 estimators and selecting boosting type. 
- @unittest.skipIf(not hummingbird_installed(), reason="Hummingbird is not installed") - def test_lightgbm_booster_multi_classifier(self): - X = [[0, 1], [1, 1], [2, 0], [1, 2], [-1, 2], [1, -2]] - X = numpy.array(X, dtype=numpy.float32) - y = [0, 1, 0, 1, 2, 2] - data = lightgbm.Dataset(X, label=y) - model = lightgbm.train( - {"boosting_type": "gbdt", "objective": "multiclass", "n_estimators": 3, "min_child_samples": 1, "num_class": 3}, - data, - ) - self._test_classifier(X, model) - if __name__ == "__main__": unittest.main() diff --git a/tests/lightgbm/test_LightGbmTreeEnsembleConverters_split.py b/tests/lightgbm/test_LightGbmTreeEnsembleConverters_split.py new file mode 100644 index 00000000..20858189 --- /dev/null +++ b/tests/lightgbm/test_LightGbmTreeEnsembleConverters_split.py @@ -0,0 +1,115 @@ +# SPDX-License-Identifier: Apache-2.0 + +import unittest +from distutils.version import StrictVersion +import lightgbm +import numpy +from numpy.testing import assert_almost_equal +from onnx.defs import onnx_opset_version +from lightgbm import LGBMRegressor +from sklearn.datasets import load_iris +from sklearn.model_selection import train_test_split +from onnxruntime import InferenceSession, __version__ as ort_version +from onnxmltools.convert.common.utils import hummingbird_installed +from onnxmltools.convert.common.data_types import FloatTensorType +from onnxmltools.convert import convert_lightgbm +from onnxmltools.utils import dump_data_and_model +from onnxmltools.utils import dump_binary_classification, dump_multiple_classification +from onnxmltools.utils import dump_single_regression +from onnxmltools.utils.tests_helper import convert_model + +TARGET_OPSET = min(13, onnx_opset_version()) + + +class TestLightGbmTreeEnsembleModelsSplit(unittest.TestCase): + + @unittest.skipIf(StrictVersion(ort_version) < StrictVersion('1.7.0'), + reason="Sum not implemented.") + def test_lgbm_regressor10(self): + data = load_iris() + X, y = data.data, data.target + X = X.astype(numpy.float32) + X_train, X_test, y_train, _ = train_test_split(X, y, random_state=0) + reg = LGBMRegressor(max_depth=2, n_estimators=4, seed=0, num_thread=1) + reg.fit(X_train, y_train) + expected = reg.predict(X_test) + + # float + init = [('X', FloatTensorType([None, X_train.shape[1]]))] + onx = convert_lightgbm(reg, None, init) + self.assertNotIn('op_type: "Sum"', str(onx)) + oinf = InferenceSession(onx.SerializeToString()) + got1 = oinf.run(None, {'X': X_test})[0] + + # float split + onx = convert_lightgbm(reg, None, init, split=2) + self.assertIn('op_type: "Sum"', str(onx)) + oinf = InferenceSession(onx.SerializeToString()) + got2 = oinf.run(None, {'X': X_test})[0] + + # final check + assert_almost_equal(expected, got1.ravel(), decimal=5) + assert_almost_equal(expected, got2.ravel(), decimal=5) + + @unittest.skipIf(StrictVersion(ort_version) < StrictVersion('1.7.0'), + reason="Sum not implemented.") + def test_lgbm_regressor(self): + data = load_iris() + X, y = data.data, data.target + X = X.astype(numpy.float32) + X_train, X_test, y_train, _ = train_test_split(X, y, random_state=0) + reg = LGBMRegressor(max_depth=2, n_estimators=100, seed=0, num_thread=1) + reg.fit(X_train, y_train) + expected = reg.predict(X_test) + + # float + init = [('X', FloatTensorType([None, X_train.shape[1]]))] + onx = convert_lightgbm(reg, None, init) + self.assertNotIn('op_type: "Sum"', str(onx)) + oinf = InferenceSession(onx.SerializeToString()) + got1 = oinf.run(None, {'X': X_test})[0] + assert_almost_equal(expected, got1.ravel(), decimal=5) + + # 
float split + onx = convert_lightgbm(reg, None, init, split=10) + self.assertIn('op_type: "Sum"', str(onx)) + oinf = InferenceSession(onx.SerializeToString()) + got2 = oinf.run(None, {'X': X_test})[0] + assert_almost_equal(expected, got2.ravel(), decimal=5) + + # final + d1 = numpy.abs(expected.ravel() - got1.ravel()).mean() + d2 = numpy.abs(expected.ravel() - got2.ravel()).mean() + self.assertGreater(d1, d2) + + @unittest.skipIf(StrictVersion(ort_version) < StrictVersion('1.7.0'), + reason="Sum not implemented.") + def test_lightgbm_booster_regressor(self): + data = load_iris() + X, y = data.data, data.target + X_train, X_test, y_train, _ = train_test_split(X, y, random_state=0) + data = lightgbm.Dataset(X_train, label=y_train) + model = lightgbm.train({'boosting_type': 'gbdt', 'objective': 'regression', + 'n_estimators': 100, 'max_depth': 2, 'num_thread': 1}, + data) + expected = model.predict(X_test) + onx = convert_lightgbm(model, '', [('X', FloatTensorType([None, 4]))]) + onx10 = convert_lightgbm(model, '', [('X', FloatTensorType([None, 4]))], split=1) + + self.assertNotIn('op_type: "Sum"', str(onx)) + oinf = InferenceSession(onx.SerializeToString()) + got1 = oinf.run(None, {'X': X_test.astype(numpy.float32)})[0] + assert_almost_equal(expected, got1.ravel(), decimal=5) + + self.assertIn('op_type: "Sum"', str(onx10)) + oinf = InferenceSession(onx10.SerializeToString()) + got2 = oinf.run(None, {'X': X_test.astype(numpy.float32)})[0] + assert_almost_equal(expected, got2.ravel(), decimal=5) + + d1 = numpy.abs(expected.ravel() - got1.ravel()).mean() + d2 = numpy.abs(expected.ravel() - got2.ravel()).mean() + self.assertGreater(d1, d2) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/lightgbm/test_lightgbm_missing_values.py b/tests/lightgbm/test_lightgbm_missing_values.py index 9236f59b..33205248 100644 --- a/tests/lightgbm/test_lightgbm_missing_values.py +++ b/tests/lightgbm/test_lightgbm_missing_values.py @@ -54,7 +54,7 @@ def test_missing_values(self): min_data_in_leaf=1, n_estimators=1, learning_rate=1, - ) + num_thread=1) regressor.fit(_X_train, _y) regressor_onnx: ModelProto = convert_lightgbm(regressor, initial_types=_INITIAL_TYPES) y_pred = regressor.predict(_X_test) @@ -65,3 +65,8 @@ def test_missing_values(self): decimal=_N_DECIMALS, frac=_FRAC, ) + + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/lightgbm/test_objective_functions.py b/tests/lightgbm/test_objective_functions.py index a49f6cac..9fc2db8a 100644 --- a/tests/lightgbm/test_objective_functions.py +++ b/tests/lightgbm/test_objective_functions.py @@ -76,7 +76,7 @@ def test_objective(self): """ for objective in self._objectives: with self.subTest(X=_X, objective=objective): - regressor = LGBMRegressor(objective=objective) + regressor = LGBMRegressor(objective=objective, num_thread=1) regressor.fit(_X, _Y) regressor_onnx: ModelProto = convert_lightgbm(regressor, initial_types=self._calc_initial_types(_X)) y_pred = regressor.predict(_X) @@ -87,3 +87,7 @@ def test_objective(self): decimal=_N_DECIMALS, frac=_FRAC, ) + + +if __name__ == "__main__": + unittest.main()
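
A minimal sketch of how the new *split* option is meant to be used. The toy model
and random data below are only illustrative; onnxruntime >= 1.7 is assumed because
the split graph relies on the Sum operator, the same condition the new tests skip on:

    import numpy
    from lightgbm import LGBMRegressor
    from onnxruntime import InferenceSession
    from onnxmltools.convert import convert_lightgbm
    from onnxmltools.convert.common.data_types import FloatTensorType

    # Train a regression forest large enough for float rounding to accumulate.
    rs = numpy.random.RandomState(0)
    X = rs.rand(500, 4).astype(numpy.float32)
    y = X.sum(axis=1) + rs.rand(500) / 10
    reg = LGBMRegressor(max_depth=2, n_estimators=100).fit(X, y)
    expected = reg.predict(X)  # lightgbm computes in double precision

    init = [('X', FloatTensorType([None, X.shape[1]]))]
    # One single TreeEnsembleRegressor node, all additions in float.
    onx_single = convert_lightgbm(reg, initial_types=init)
    # Ten trees per TreeEnsembleRegressor node, outputs cast to double before summing.
    onx_split = convert_lightgbm(reg, initial_types=init, split=10)

    for onx in (onx_single, onx_split):
        sess = InferenceSession(onx.SerializeToString())
        got = sess.run(None, {'X': X})[0].ravel()
        print(numpy.abs(expected - got).mean())

The split model should print the smaller mean absolute difference, which is what
test_LightGbmTreeEnsembleConverters_split.py asserts with its d1 > d2 check.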