Skip to content

Commit

Permalink
Implements option zipmap for MultiOutputClassifier (#775)
Browse files Browse the repository at this point in the history
* Implements option zipmap for MultiOutputClassifier
  • Loading branch information
xadupre authored Nov 10, 2021
1 parent a26899f commit 813c861
Show file tree
Hide file tree
Showing 14 changed files with 423 additions and 70 deletions.
2 changes: 2 additions & 0 deletions .azure-pipelines/linux-conda-CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,8 @@ jobs:
echo "---------------"
pip show numpy
echo "---------------"
pip show scipy
echo "---------------"
pip show pandas
echo "---------------"
pip show onnx
Expand Down
91 changes: 72 additions & 19 deletions skl2onnx/_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,17 +413,9 @@ def _parse_sklearn_random_trees_embedding(scope, model, inputs,
return res


def _parse_sklearn_classifier(scope, model, inputs, custom_parsers=None):
options = scope.get_options(model, dict(zipmap=True))
no_zipmap = (
(isinstance(options['zipmap'], bool) and not options['zipmap']) or
(model.__class__ in [NuSVC, SVC] and not model.probability))
probability_tensor = _parse_sklearn_simple_model(
scope, model, inputs, custom_parsers=custom_parsers)
if no_zipmap:
return probability_tensor

if options['zipmap'] == 'columns':
def _apply_zipmap(zipmap_options, scope, model, input_type,
probability_tensor):
if zipmap_options == 'columns':
zipmap_operator = scope.declare_local_operator('SklearnZipMapColumns')
classes = get_label_classes(scope, model)
classes_names = get_label_classes(scope, model, node_names=True)
Expand Down Expand Up @@ -456,10 +448,11 @@ def _parse_sklearn_classifier(scope, model, inputs, custom_parsers=None):
label_type = StringTensorType([None])

zip_label = scope.declare_local_variable('output_label', label_type)
zipmap_operator.outputs.append(zip_label)
if len(probability_tensor) == 2:
zipmap_operator.outputs.append(zip_label)

if options['zipmap'] == 'columns':
prob_type = probability_tensor[1].type
if zipmap_options == 'columns':
prob_type = probability_tensor[-1].type
for cl in classes_names:
output_cl = scope.declare_local_variable(cl, prob_type.__class__())
zipmap_operator.outputs.append(output_cl)
Expand All @@ -468,24 +461,84 @@ def _parse_sklearn_classifier(scope, model, inputs, custom_parsers=None):
'output_probability',
SequenceType(
DictionaryType(
label_type, guess_tensor_type(inputs[0].type))))
label_type, guess_tensor_type(input_type))))
zipmap_operator.outputs.append(zip_probability)

zipmap_operator.init_status(is_evaluated=True)
return zipmap_operator.outputs


def _parse_sklearn_classifier(scope, model, inputs, custom_parsers=None):
    """
    Parses a classifier and, unless it is skipped, plugs a zipmap
    operator after it.

    :param scope: scope, used to retrieve the options attached to *model*
    :param model: classifier to parse
    :param inputs: input variables, the type of the first one is
        forwarded to ``_apply_zipmap``
    :param custom_parsers: forwarded to ``_parse_sklearn_simple_model``
    :return: the variables returned by ``_parse_sklearn_simple_model``
        when zipmap is skipped, the variables returned by
        ``_apply_zipmap`` otherwise
    """
    zipmap_option = scope.get_options(model, dict(zipmap=True))['zipmap']
    if isinstance(zipmap_option, bool) and not zipmap_option:
        # zipmap was explicitly switched off
        skip_zipmap = True
    else:
        # NuSVC and SVC with probability=False skip zipmap as well
        skip_zipmap = (
            model.__class__ in (NuSVC, SVC) and not model.probability)

    raw_outputs = _parse_sklearn_simple_model(
        scope, model, inputs, custom_parsers=custom_parsers)
    if skip_zipmap:
        return raw_outputs
    return _apply_zipmap(
        zipmap_option, scope, model, inputs[0].type, raw_outputs)


def _parse_sklearn_multi_output_classifier(scope, model, inputs,
                                           custom_parsers=None):
    """
    Parses a multi-output classifier and optionally zips the
    probabilities of every estimator.

    :param scope: scope, used to declare operators and variables and to
        retrieve the options attached to *model*
    :param model: fitted model, ``model.estimators_`` is iterated when
        zipmap is enabled
    :param inputs: input variables, the type of the first one is used to
        guess the type of the probabilities
    :param custom_parsers: unused by this parser
    :return: ``[label, probabilities]`` when option *zipmap* is False,
        otherwise the outputs of the *SklearnSequenceConstruct* operator
        (one variable named ``probabilities``, declared as a sequence
        of dictionaries)
    :raises RuntimeError: if option *zipmap* is ``'columns'``

    A ``RuntimeWarning`` is emitted whenever zipmap is not disabled.
    """
    alias = _get_sklearn_operator_name(type(model))
    this_operator = scope.declare_local_operator(alias, model)
    this_operator.inputs = inputs
    guessed_output_type = guess_tensor_type(inputs[0].type)
    label_type = Int64TensorType()
    label = scope.declare_local_variable("label", label_type)

    options = scope.get_options(model, dict(zipmap=True))
    zipmap = options['zipmap']
    no_zipmap = isinstance(zipmap, bool) and not zipmap

    if no_zipmap:
        # Raw outputs: the labels and a sequence of probability tensors.
        proba = scope.declare_local_variable(
            "probabilities", SequenceType(guessed_output_type))
        this_operator.outputs.append(label)
        this_operator.outputs.append(proba)
        return this_operator.outputs

    warnings.warn(
        "The current converter for class %r "
        "creates a sequence of maps as an output. "
        "This is not allowed in standard ONNX specifications."
        "" % model.__class__.__name__,
        RuntimeWarning)

    # The raw sequence is kept as an intermediate result under another
    # name, "probabilities" is reserved for the zipped sequence below.
    this_operator.outputs.append(label)
    proba = scope.declare_local_variable(
        "output_probabilities", SequenceType(guessed_output_type))
    this_operator.outputs.append(proba)

    if zipmap == 'columns':
        raise RuntimeError(
            "Unable to convert model %r with option zipmap=%r." % (
                model.__class__.__name__, zipmap))

    # zipmap is True: extract the probabilities of every estimator from
    # the sequence, zip them, then rebuild a sequence out of the maps.
    zipped = []
    for i, estimator in enumerate(model.estimators_):
        output_name = scope.declare_local_variable(
            "multi_output_%d" % i, guessed_output_type)
        sequence_at = scope.declare_local_operator('SklearnSequenceAt')
        sequence_at.inputs.append(proba)
        sequence_at.outputs.append(output_name)
        sequence_at.index = i
        zipped.extend(
            _apply_zipmap(
                zipmap, scope, estimator,
                guessed_output_type, sequence_at.outputs))

    sequence_build = scope.declare_local_operator('SklearnSequenceConstruct')
    sequence_build.inputs.extend(zipped)
    proba_zipped = scope.declare_local_variable(
        "probabilities",
        SequenceType(DictionaryType(label_type, guessed_output_type)))
    sequence_build.outputs.append(proba_zipped)
    # NOTE(review): only the zipped probabilities are returned here, the
    # label variable is not part of the returned list - confirm intended.
    return sequence_build.outputs


def _parse_sklearn_gaussian_process(scope, model, inputs, custom_parsers=None):
Expand Down
2 changes: 2 additions & 0 deletions skl2onnx/operator_converters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from . import ransac_regressor
from . import replace_op
from . import scaler_op
from . import sequence
from . import sgd_classifier
from . import stacking
from . import support_vector_machines
Expand Down Expand Up @@ -108,6 +109,7 @@
ransac_regressor,
replace_op,
scaler_op,
sequence,
sgd_classifier,
stacking,
support_vector_machines,
Expand Down
10 changes: 5 additions & 5 deletions skl2onnx/operator_converters/multioutput.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,12 @@ def convert_multi_output_classifier_converter(
op_version=op_version)
for y in y_list]

label = OnnxConcat(*label_list, axis=1, op_version=op_version,
output_names=[operator.outputs[0]])
label.add_to(scope=scope, container=container)

# probabilities
proba_list = [OnnxIdentity(y[1], op_version=op_version)
for y in y_list]
label = OnnxConcat(*label_list, axis=1, op_version=op_version,
output_names=[operator.outputs[0]])
label.add_to(scope=scope, container=container)

proba = OnnxSequenceConstruct(
*proba_list, op_version=op_version,
Expand All @@ -79,4 +78,5 @@ def convert_multi_output_classifier_converter(
convert_multi_output_regressor_converter)
# The classifier converter accepts option 'zipmap' in addition to 'nocl'.
register_converter('SklearnMultiOutputClassifier',
                   convert_multi_output_classifier_converter,
                   options={'nocl': [False, True],
                            'zipmap': [False, True]})
31 changes: 31 additions & 0 deletions skl2onnx/operator_converters/sequence.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# SPDX-License-Identifier: Apache-2.0

from ..proto import onnx_proto
from ..common._registration import register_converter
from ..common._topology import Scope, Operator
from ..common._container import ModelComponentContainer


def convert_sklearn_sequence_at(scope: Scope, operator: Operator,
                                container: ModelComponentContainer):
    """
    Converts operator *SklearnSequenceAt* into ONNX node *SequenceAt*.
    The position to extract is read from ``operator.index`` and stored
    as an int64 initializer with an empty shape.
    """
    position = operator.index

    position_name = scope.get_unique_variable_name("seq_at%d" % position)
    container.add_initializer(
        position_name, onnx_proto.TensorProto.INT64, [], [position])

    node_inputs = [operator.inputs[0].full_name, position_name]
    node_output = operator.outputs[0].full_name
    node_name = scope.get_unique_operator_name('SequenceAt%d' % position)
    container.add_node(
        'SequenceAt', node_inputs, node_output, name=node_name)


def convert_sklearn_sequence_construct(scope: Scope, operator: Operator,
                                       container: ModelComponentContainer):
    """
    Converts operator *SklearnSequenceConstruct* into ONNX node
    *SequenceConstruct*: every input of the operator becomes an input of
    the node, the first output of the operator receives the result.
    """
    element_names = []
    for variable in operator.inputs:
        element_names.append(variable.full_name)
    sequence_name = operator.outputs[0].full_name
    node_name = scope.get_unique_operator_name('SequenceConstruct')
    container.add_node(
        'SequenceConstruct', element_names, sequence_name, name=node_name)


# Registers the two converters defined above under the operator aliases
# 'SklearnSequenceAt' and 'SklearnSequenceConstruct'.
register_converter('SklearnSequenceAt', convert_sklearn_sequence_at)
register_converter(
'SklearnSequenceConstruct', convert_sklearn_sequence_construct)
21 changes: 18 additions & 3 deletions skl2onnx/operator_converters/zip_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,24 @@ def _common_convert_sklearn_zipmap(scope: Scope, operator: Operator,

def convert_sklearn_zipmap(scope: Scope, operator: Operator,
                           container: ModelComponentContainer):
    """
    Converts operator *SklearnZipMap* into ONNX node *ZipMap*
    (domain ``ai.onnx.ml``).

    With two inputs, the attributes come from
    ``_common_convert_sklearn_zipmap`` and the node maps the second
    input to the second output. Otherwise the class labels are read
    from the operator itself (``classlabels_int64s`` or
    ``classlabels_strings``) and the node maps the first input to the
    first output.

    :raises RuntimeError: if the operator does not have two inputs and
        defines neither ``classlabels_int64s`` nor
        ``classlabels_strings``
    """
    if len(operator.inputs) == 2:
        zipmap_attrs = _common_convert_sklearn_zipmap(
            scope, operator, container)
        container.add_node('ZipMap', operator.inputs[1].full_name,
                           operator.outputs[1].full_name,
                           op_domain='ai.onnx.ml', **zipmap_attrs)
        return

    # No label input: the labels must be carried by the operator itself.
    if hasattr(operator, 'classlabels_int64s'):
        zipmap_attrs = dict(classlabels_int64s=operator.classlabels_int64s)
    elif hasattr(operator, 'classlabels_strings'):
        zipmap_attrs = dict(classlabels_strings=operator.classlabels_strings)
    else:
        raise RuntimeError(
            "operator should have attribute 'classlabels_int64s' or "
            "'classlabels_strings'.")
    container.add_node('ZipMap', operator.inputs[0].full_name,
                       operator.outputs[0].full_name,
                       op_domain='ai.onnx.ml', **zipmap_attrs)


Expand Down
2 changes: 2 additions & 0 deletions skl2onnx/shape_calculators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from . import random_trees_embedding
from . import replace_op
from . import scaler
from . import sequence
from . import svd
from . import support_vector_machines
from . import text_vectorizer
Expand Down Expand Up @@ -80,6 +81,7 @@
random_trees_embedding,
replace_op,
scaler,
sequence,
svd,
support_vector_machines,
text_vectorizer,
Expand Down
16 changes: 16 additions & 0 deletions skl2onnx/shape_calculators/sequence.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# SPDX-License-Identifier: Apache-2.0

from ..common._registration import register_shape_calculator


def calculate_sklearn_sequence_at(operator):
    """
    Shape calculator of *SklearnSequenceAt*: does nothing, the operator
    is left untouched.
    """
    return None


def calculate_sklearn_sequence_construct(operator):
    """
    Shape calculator of *SklearnSequenceConstruct*: does nothing, the
    operator is left untouched.
    """
    return None


# Registers the two shape calculators defined above under the operator
# aliases 'SklearnSequenceAt' and 'SklearnSequenceConstruct'.
register_shape_calculator('SklearnSequenceAt', calculate_sklearn_sequence_at)
register_shape_calculator(
'SklearnSequenceConstruct', calculate_sklearn_sequence_construct)
11 changes: 7 additions & 4 deletions skl2onnx/shape_calculators/zip_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@


from ..common._registration import register_shape_calculator
from ..common.utils import check_input_and_output_numbers


def calculate_sklearn_zipmap(operator):
    """
    Shape calculator of *SklearnZipMap*.

    The operator must have as many outputs as inputs, either one or two.
    With two inputs, the type of the first output is rebuilt from the
    class and the shape of the type of the first input. With a single
    input, no output type is modified.

    :param operator: operator exposing ``inputs`` and ``outputs``
    :raises RuntimeError: if the numbers of inputs and outputs differ
        or are not 1 or 2
    """
    n_inputs = len(operator.inputs)
    n_outputs = len(operator.outputs)
    if n_inputs != n_outputs or n_inputs not in (1, 2):
        raise RuntimeError(
            "SklearnZipMap expects the same number of inputs and outputs "
            "(1 or 2), got %d input(s) and %d output(s)." % (
                n_inputs, n_outputs))
    if n_inputs == 2:
        # Same type class and same shape as the first input.
        first_type = operator.inputs[0].type
        operator.outputs[0].type = first_type.__class__(first_type.shape)


def calculate_sklearn_zipmap_columns(operator):
Expand Down
59 changes: 30 additions & 29 deletions tests/test_algebra_cascade.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,35 +214,36 @@ def test_model_mlp_regressor_default(self):
for opv in (1, 2, 7, 8, 9, 10, 11, 12, 13, onnx_opset_version()):
if opv is not None and opv > TARGET_OPSET:
continue
try:
onx = convert_sklearn(
model, "scikit-learn MLPRegressor",
[("input", FloatTensorType([None, X_test.shape[1]]))],
target_opset=opv)
except RuntimeError as e:
if ("is higher than the number of the "
"installed onnx package") in str(e):
continue
raise e
as_string = onx.SerializeToString()
try:
ort = InferenceSession(as_string)
except (RuntimeError, InvalidGraph, Fail) as e:
if opv in (None, 1, 2):
continue
if opv >= onnx_opset_version():
continue
if ("No suitable kernel definition found for "
"op Cast(9)") in str(e):
# too old onnxruntime
continue
raise AssertionError(
"Unable to load opv={}\n---\n{}\n---".format(
opv, onx)) from e
res_out = ort.run(None, {'input': X_test})
assert len(res_out) == 1
res = res_out[0]
assert_almost_equal(exp.ravel(), res.ravel(), decimal=4)
with self.subTest(opv=opv):
try:
onx = convert_sklearn(
model, "scikit-learn MLPRegressor",
[("input", FloatTensorType([None, X_test.shape[1]]))],
target_opset=opv)
except RuntimeError as e:
if ("is higher than the number of the "
"installed onnx package") in str(e):
continue
raise e
as_string = onx.SerializeToString()
try:
ort = InferenceSession(as_string)
except (RuntimeError, InvalidGraph, Fail) as e:
if opv in (None, 1, 2):
continue
if opv >= onnx_opset_version():
continue
if ("No suitable kernel definition found for "
"op Cast(9)") in str(e):
# too old onnxruntime
continue
raise AssertionError(
"Unable to load opv={}\n---\n{}\n---".format(
opv, onx)) from e
res_out = ort.run(None, {'input': X_test})
assert len(res_out) == 1
res = res_out[0]
assert_almost_equal(exp.ravel(), res.ravel(), decimal=4)


if __name__ == "__main__":
Expand Down
Loading

0 comments on commit 813c861

Please sign in to comment.