Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use Sequence instead of Tensor for MultiOutputClassifier probabilities #730

Merged
merged 8 commits into from
Sep 16, 2021
15 changes: 14 additions & 1 deletion skl2onnx/_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,19 @@ def _parse_sklearn_classifier(scope, model, inputs, custom_parsers=None):
return zipmap_operator.outputs


def _parse_sklearn_multi_output_classifier(scope, model, inputs,
custom_parsers=None):
alias = _get_sklearn_operator_name(type(model))
this_operator = scope.declare_local_operator(alias, model)
this_operator.inputs = inputs
label = scope.declare_local_variable("label", Int64TensorType())
proba = scope.declare_local_variable(
"probabilities", SequenceType(guess_tensor_type(inputs[0].type)))
this_operator.outputs.append(label)
this_operator.outputs.append(proba)
return this_operator.outputs


def _parse_sklearn_gaussian_process(scope, model, inputs, custom_parsers=None):
options = scope.get_options(
model, dict(return_cov=False, return_std=False))
Expand Down Expand Up @@ -683,7 +696,7 @@ def build_sklearn_parsers_map():
BayesianRidge: _parse_sklearn_bayesian_ridge,
GaussianProcessRegressor: _parse_sklearn_gaussian_process,
GridSearchCV: _parse_sklearn_grid_search_cv,
MultiOutputClassifier: _parse_sklearn_simple_model,
MultiOutputClassifier: _parse_sklearn_multi_output_classifier,
}
if ColumnTransformer is not None:
map_parser[ColumnTransformer] = _parse_sklearn_column_transformer
Expand Down
23 changes: 21 additions & 2 deletions skl2onnx/algebra/type_helper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
# SPDX-License-Identifier: Apache-2.0

import numpy as np
try:
from pandas import DataFrame
except ImportError:
DataFrame = None
from scipy.sparse import coo_matrix
from ..proto import TensorProto, ValueInfoProto
from ..common._topology import Variable
Expand Down Expand Up @@ -79,6 +83,21 @@ def guess_initial_types(X, initial_types):
if initial_types is None:
if isinstance(X, np.ndarray):
X = X[:1]
gt = _guess_type(X)
initial_types = [('X', gt)]
gt = _guess_type(X)
initial_types = [('X', gt)]
elif DataFrame is not None and isinstance(X, DataFrame):
X = X[:1]
initial_types = []
for c in X.columns:
if isinstance(X[c].values[0], (str, np.str_)):
g = StringTensorType()
else:
g = _guess_type(X[c].values)
g.shape = [None, 1]
initial_types.append((c, g))
elif isinstance(X, list):
initial_types = X
else:
raise TypeError(
"Unexpected type %r, unable to guess type." % type(X))
return initial_types
18 changes: 13 additions & 5 deletions skl2onnx/operator_converters/multioutput.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,13 @@
from ..common._registration import register_converter
from ..common._topology import Scope, Operator
from ..common._container import ModelComponentContainer
from ..algebra.onnx_ops import OnnxConcat, OnnxReshape, OnnxIdentity
from ..algebra.onnx_ops import (
OnnxConcat, OnnxReshape, OnnxIdentity)
try:
from ..algebra.onnx_ops import OnnxSequenceConstruct
except ImportError:
# Available since opset 11
OnnxSequenceConstruct = None
from ..algebra.onnx_operator import OnnxSubEstimator


Expand Down Expand Up @@ -33,6 +39,10 @@ def convert_multi_output_classifier_converter(
"""
Converts a *MultiOutputClassifier* into *ONNX* format.
"""
if OnnxSequenceConstruct is None:
raise RuntimeError(
"This converter requires opset>=11.")
op_version = container.target_opset
op_version = container.target_opset
op = operator.raw_operator
inp = operator.inputs[0]
Expand All @@ -59,10 +69,8 @@ def convert_multi_output_classifier_converter(
proba_list = [OnnxIdentity(y[1], op_version=op_version)
for y in y_list]

proba = OnnxReshape(
OnnxConcat(*proba_list, axis=1, op_version=op_version),
np.array([-1, len(op.estimators_), 2], dtype=np.int64),
op_version=op_version,
proba = OnnxSequenceConstruct(
*proba_list, op_version=op_version,
output_names=[operator.outputs[1]])
proba.add_to(scope=scope, container=container)

Expand Down
11 changes: 10 additions & 1 deletion skl2onnx/shape_calculators/multioutput.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@


from ..common._registration import register_shape_calculator
from ..common.utils import check_input_and_output_numbers
from ..common.data_types import SequenceType

_stack = []


def multioutput_regressor_shape_calculator(operator):
"""Shape calculator for MultiOutputRegressor"""
check_input_and_output_numbers(
operator, input_count_range=1, output_count_range=1)
i = operator.inputs[0]
o = operator.outputs[0]
N = i.get_first_dimension()
Expand All @@ -17,12 +21,17 @@ def multioutput_regressor_shape_calculator(operator):

def multioutput_classifier_shape_calculator(operator):
"""Shape calculator for MultiOutputClassifier"""
check_input_and_output_numbers(
operator, input_count_range=1, output_count_range=2)
if not isinstance(operator.outputs[1].type, SequenceType):
raise RuntimeError(
"Probabilites should be a sequence not %r."
"" % operator.outputs[1].type)
i = operator.inputs[0]
outputs = operator.outputs
N = i.get_first_dimension()
C = len(operator.raw_operator.estimators_)
outputs[0].type.shape = [N, C]
outputs[1].type.shape = [N, C, 2]


register_shape_calculator('SklearnMultiOutputRegressor',
Expand Down
20 changes: 14 additions & 6 deletions tests/test_sklearn_multi_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ def test_multi_output_regressor(self):
X.astype(numpy.float32), clf, onx,
basename="SklearnMultiOutputRegressor")

@unittest.skipIf(TARGET_OPSET < 11,
reason="SequenceConstruct not available.")
def test_multi_output_classifier(self):
X, y = make_multilabel_classification(n_classes=3, random_state=0)
X = X.astype(numpy.float32)
Expand All @@ -44,9 +46,11 @@ def test_multi_output_classifier(self):
sess = InferenceSession(onx.SerializeToString())
res = sess.run(None, {'X': X})
exp_lab = clf.predict(X)
exp_prb = numpy.transpose(clf.predict_proba(X), (1, 0, 2))
exp_prb = clf.predict_proba(X)
assert_almost_equal(exp_lab, res[0])
assert_almost_equal(exp_prb, res[1], decimal=5)
self.assertEqual(len(exp_prb), len(res[1]))
for e, g in zip(exp_prb, res[1]):
assert_almost_equal(e, g, decimal=5)

# check option nocl=True
onx = to_onnx(clf, X[:1], target_opset=TARGET_OPSET,
Expand All @@ -56,9 +60,11 @@ def test_multi_output_classifier(self):
sess = InferenceSession(onx.SerializeToString())
res = sess.run(None, {'X': X})
exp_lab = clf.predict(X)
exp_prb = numpy.transpose(clf.predict_proba(X), (1, 0, 2))
exp_prb = clf.predict_proba(X)
assert_almost_equal(exp_lab, res[0])
assert_almost_equal(exp_prb, res[1], decimal=5)
self.assertEqual(len(exp_prb), len(res[1]))
for e, g in zip(exp_prb, res[1]):
assert_almost_equal(e, g, decimal=5)

# check option nocl=False
onx = to_onnx(clf, X[:1], target_opset=TARGET_OPSET,
Expand All @@ -68,9 +74,11 @@ def test_multi_output_classifier(self):
sess = InferenceSession(onx.SerializeToString())
res = sess.run(None, {'X': X})
exp_lab = clf.predict(X)
exp_prb = numpy.transpose(clf.predict_proba(X), (1, 0, 2))
exp_prb = clf.predict_proba(X)
assert_almost_equal(exp_lab, res[0])
assert_almost_equal(exp_prb, res[1], decimal=5)
self.assertEqual(len(exp_prb), len(res[1]))
for e, g in zip(exp_prb, res[1]):
assert_almost_equal(e, g, decimal=5)


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion tests/test_sklearn_nearest_neighbour_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def test_model_knn_regressor_double_radius(self):
model, model_onnx,
basename="SklearnRadiusNeighborsRegressor64")
dump_data_and_model(
(X + 0.1).astype(numpy.float64)[:7],
(X + 10.).astype(numpy.float64)[:7],
model, model_onnx,
basename="SklearnRadiusNeighborsRegressor64")

Expand Down
51 changes: 47 additions & 4 deletions tests/test_sklearn_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from sklearn.ensemble import VotingClassifier, RandomForestClassifier
from sklearn.naive_bayes import MultinomialNB
from sklearn.svm import SVC
from skl2onnx import convert_sklearn
from skl2onnx import convert_sklearn, to_onnx
from skl2onnx.common.data_types import (
FloatTensorType,
Int64TensorType,
Expand Down Expand Up @@ -673,6 +673,8 @@ def test_pipeline_pipeline_voting_tfidf_svc(self):
assert_almost_equal(expected_proba, got[1])
assert_almost_equal(expected_label, got[0])

@unittest.skipIf(TARGET_OPSET < 11,
reason="SequenceConstruct not available")
@ignore_warnings(category=(FutureWarning, UserWarning))
def test_pipeline_pipeline_rf(self):
cat_feat = ['A', 'B']
Expand Down Expand Up @@ -714,15 +716,56 @@ def test_pipeline_pipeline_rf(self):
sess = InferenceSession(model_onnx.SerializeToString())
got = sess.run(None, {'A': data[:, :1], 'B': data[:, 1:2],
'TEXT': data[:, 2:]})
assert_almost_equal(
numpy.transpose(expected_proba, (1, 0, 2)), got[1])
self.assertEqual(len(expected_proba), len(got[1]))
for e, g in zip(expected_proba, got[1]):
assert_almost_equal(e, g, decimal=5)
assert_almost_equal(expected_label, got[0])

@unittest.skipIf(TARGET_OPSET < 11,
reason="SequenceConstruct not available")
def test_issue_712(self):
dfx = pandas.DataFrame(
{'CAT1': ['985332', '985333', '985334', '985335', '985336'],
'CAT2': ['1985332', '1985333', '1985334', '1985335', '1985336'],
'TEXT': ["abc abc", "abc def", "def ghj", "abcdef", "abc ii"]})
dfy = pandas.DataFrame(
{'REAL': [5, 6, 7, 6, 5],
'CATY': [0, 1, 0, 1, 0]})

cat_features = ['CAT1', 'CAT2']
categorical_transformer = OneHotEncoder(handle_unknown='ignore')
textual_feature = 'TEXT'
count_vect_transformer = Pipeline(steps=[
('count_vect', CountVectorizer(
max_df=0.8, min_df=0.05, max_features=1000))])
preprocessor = ColumnTransformer(
transformers=[
('cat_transform', categorical_transformer, cat_features),
('count_vector', count_vect_transformer, textual_feature)])
model_RF = RandomForestClassifier(random_state=42, max_depth=50)
rf_clf = Pipeline(steps=[
('preprocessor', preprocessor),
('classifier', MultiOutputClassifier(estimator=model_RF))])
rf_clf.fit(dfx, dfy)
expected_label = rf_clf.predict(dfx)
expected_proba = rf_clf.predict_proba(dfx)

inputs = {'CAT1': dfx['CAT1'].values.reshape((-1, 1)),
'CAT2': dfx['CAT2'].values.reshape((-1, 1)),
'TEXT': dfx['TEXT'].values.reshape((-1, 1))}
onx = to_onnx(rf_clf, dfx, target_opset=TARGET_OPSET)
sess = InferenceSession(onx.SerializeToString())

got = sess.run(None, inputs)
assert_almost_equal(expected_label, got[0])
self.assertEqual(len(expected_proba), len(got[1]))
for e, g in zip(expected_proba, got[1]):
assert_almost_equal(e, g, decimal=5)


if __name__ == "__main__":
# import logging
# logger = logging.getLogger('skl2onnx')
# logger.setLevel(logging.DEBUG)
# logging.basicConfig(level=logging.DEBUG)
# TestSklearnPipeline().test_pipeline_pipeline_rf()
unittest.main()