Skip to content

Commit

Permalink
Fixes #813, support unsigned class labels in MLPClassifier (#818)
Browse files Browse the repository at this point in the history
* Fix unsigned class lablels
* add uint64
  • Loading branch information
xadupre authored Feb 2, 2022
1 parent 950f7c5 commit 544570e
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 98 deletions.
2 changes: 1 addition & 1 deletion skl2onnx/operator_converters/multilayer_perceptron.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def convert_sklearn_mlp_classifier(scope: Scope, operator: Operator,
if np.issubdtype(mlp_op.classes_.dtype, np.floating):
class_type = onnx_proto.TensorProto.INT32
classes = classes.astype(np.int32)
elif np.issubdtype(mlp_op.classes_.dtype, np.signedinteger):
elif np.issubdtype(mlp_op.classes_.dtype, np.integer):
class_type = onnx_proto.TensorProto.INT32
else:
classes = np.array([s.encode('utf-8') for s in classes])
Expand Down
159 changes: 63 additions & 96 deletions tests/test_sklearn_mlp_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ def ignore_warnings(category=Warning):
)


ort_version = ".".join(ort_version.split('.')[:2])


class TestSklearnMLPConverters(unittest.TestCase):
@unittest.skipIf(not onnx_built_with_ml(),
reason="Requires ONNX-ML extension.")
Expand All @@ -52,13 +55,8 @@ def test_model_mlp_classifier_binary(self):
)
self.assertTrue(model_onnx is not None)
dump_data_and_model(
X_test,
model,
model_onnx,
basename="SklearnMLPClassifierBinary",
allow_failure="StrictVersion("
"onnxruntime.__version__)<= StrictVersion('0.2.1')",
)
X_test, model, model_onnx,
basename="SklearnMLPClassifierBinary")

@unittest.skipIf(not onnx_built_with_ml(),
reason="Requires ONNX-ML extension.")
Expand All @@ -74,13 +72,42 @@ def test_model_mlp_classifier_multiclass_default(self):
)
self.assertTrue(model_onnx is not None)
dump_data_and_model(
X_test,
X_test, model, model_onnx,
basename="SklearnMLPClassifierMultiClass")

@unittest.skipIf(not onnx_built_with_ml(),
reason="Requires ONNX-ML extension.")
@ignore_warnings(category=(ConvergenceWarning, FutureWarning))
def test_model_mlp_classifier_multiclass_default_uint8(self):
model, X_test = fit_classification_model(
MLPClassifier(random_state=42), 4, cls_dtype=np.uint8)
model_onnx = convert_sklearn(
model,
"scikit-learn MLPClassifier",
[("input", FloatTensorType([None, X_test.shape[1]]))],
target_opset=TARGET_OPSET
)
self.assertTrue(model_onnx is not None)
dump_data_and_model(
X_test, model, model_onnx,
basename="SklearnMLPClassifierMultiClassU8")

@unittest.skipIf(not onnx_built_with_ml(),
reason="Requires ONNX-ML extension.")
@ignore_warnings(category=(ConvergenceWarning, FutureWarning))
def test_model_mlp_classifier_multiclass_default_uint64(self):
model, X_test = fit_classification_model(
MLPClassifier(random_state=42), 4, cls_dtype=np.uint64)
model_onnx = convert_sklearn(
model,
model_onnx,
basename="SklearnMLPClassifierMultiClass",
allow_failure="StrictVersion("
"onnxruntime.__version__)<= StrictVersion('0.2.1')",
"scikit-learn MLPClassifier",
[("input", FloatTensorType([None, X_test.shape[1]]))],
target_opset=TARGET_OPSET
)
self.assertTrue(model_onnx is not None)
dump_data_and_model(
X_test, model, model_onnx,
basename="SklearnMLPClassifierMultiClassU64")

@unittest.skipIf(not onnx_built_with_ml(),
reason="Requires ONNX-ML extension.")
Expand All @@ -96,13 +123,8 @@ def test_model_mlp_classifier_multilabel_default(self):
)
self.assertTrue(model_onnx is not None)
dump_data_and_model(
X_test,
model,
model_onnx,
basename="SklearnMLPClassifierMultiLabel",
allow_failure="StrictVersion("
"onnxruntime.__version__)<= StrictVersion('0.2.1')",
)
X_test, model, model_onnx,
basename="SklearnMLPClassifierMultiLabel")

@ignore_warnings(category=(ConvergenceWarning, FutureWarning))
def test_model_mlp_regressor_default(self):
Expand All @@ -116,13 +138,8 @@ def test_model_mlp_regressor_default(self):
)
self.assertTrue(model_onnx is not None)
dump_data_and_model(
X_test,
model,
model_onnx,
basename="SklearnMLPRegressor-Dec4",
allow_failure="StrictVersion("
"onnxruntime.__version__)<= StrictVersion('0.2.1')",
)
X_test, model, model_onnx,
basename="SklearnMLPRegressor-Dec4")

@unittest.skipIf(not onnx_built_with_ml(),
reason="Requires ONNX-ML extension.")
Expand All @@ -139,13 +156,8 @@ def test_model_mlp_classifier_multiclass_identity(self):
)
self.assertTrue(model_onnx is not None)
dump_data_and_model(
X_test,
model,
model_onnx,
basename="SklearnMLPClassifierMultiClassIdentityActivation",
allow_failure="StrictVersion("
"onnxruntime.__version__)<= StrictVersion('0.2.1')",
)
X_test, model, model_onnx,
basename="SklearnMLPClassifierMultiClassIdentityActivation")

@unittest.skipIf(not onnx_built_with_ml(),
reason="Requires ONNX-ML extension.")
Expand All @@ -162,13 +174,8 @@ def test_model_mlp_classifier_multilabel_identity(self):
)
self.assertTrue(model_onnx is not None)
dump_data_and_model(
X_test,
model,
model_onnx,
basename="SklearnMLPClassifierMultiLabelIdentityActivation",
allow_failure="StrictVersion("
"onnxruntime.__version__)<= StrictVersion('0.2.1')",
)
X_test, model, model_onnx,
basename="SklearnMLPClassifierMultiLabelIdentityActivation")

@ignore_warnings(category=(ConvergenceWarning, FutureWarning))
def test_model_mlp_regressor_identity(self):
Expand All @@ -182,13 +189,8 @@ def test_model_mlp_regressor_identity(self):
)
self.assertTrue(model_onnx is not None)
dump_data_and_model(
X_test,
model,
model_onnx,
basename="SklearnMLPRegressorIdentityActivation-Dec4",
allow_failure="StrictVersion("
"onnxruntime.__version__)<= StrictVersion('0.2.1')",
)
X_test, model, model_onnx,
basename="SklearnMLPRegressorIdentityActivation-Dec4")

@unittest.skipIf(not onnx_built_with_ml(),
reason="Requires ONNX-ML extension.")
Expand All @@ -204,13 +206,8 @@ def test_model_mlp_classifier_multiclass_logistic(self):
)
self.assertTrue(model_onnx is not None)
dump_data_and_model(
X_test,
model,
model_onnx,
basename="SklearnMLPClassifierMultiClassLogisticActivation",
allow_failure="StrictVersion("
"onnxruntime.__version__)<= StrictVersion('0.2.1')",
)
X_test, model, model_onnx,
basename="SklearnMLPClassifierMultiClassLogisticActivation")

@unittest.skipIf(not onnx_built_with_ml(),
reason="Requires ONNX-ML extension.")
Expand All @@ -226,13 +223,8 @@ def test_model_mlp_classifier_multilabel_logistic(self):
)
self.assertTrue(model_onnx is not None)
dump_data_and_model(
X_test,
model,
model_onnx,
basename="SklearnMLPClassifierMultiLabelLogisticActivation",
allow_failure="StrictVersion("
"onnxruntime.__version__)<= StrictVersion('0.2.1')",
)
X_test, model, model_onnx,
basename="SklearnMLPClassifierMultiLabelLogisticActivation")

@unittest.skipIf(not onnx_built_with_ml(),
reason="Requires ONNX-ML extension.")
Expand All @@ -248,13 +240,8 @@ def test_model_mlp_regressor_logistic(self):
)
self.assertTrue(model_onnx is not None)
dump_data_and_model(
X_test,
model,
model_onnx,
basename="SklearnMLPRegressorLogisticActivation-Dec4",
allow_failure="StrictVersion("
"onnxruntime.__version__)<= StrictVersion('0.2.1')",
)
X_test, model, model_onnx,
basename="SklearnMLPRegressorLogisticActivation-Dec4")

@unittest.skipIf(not onnx_built_with_ml(),
reason="Requires ONNX-ML extension.")
Expand All @@ -270,13 +257,8 @@ def test_model_mlp_classifier_multiclass_tanh(self):
)
self.assertTrue(model_onnx is not None)
dump_data_and_model(
X_test,
model,
model_onnx,
basename="SklearnMLPClassifierMultiClassTanhActivation",
allow_failure="StrictVersion("
"onnxruntime.__version__)<= StrictVersion('0.2.1')",
)
X_test, model, model_onnx,
basename="SklearnMLPClassifierMultiClassTanhActivation")

@unittest.skipIf(not onnx_built_with_ml(),
reason="Requires ONNX-ML extension.")
Expand All @@ -292,13 +274,8 @@ def test_model_mlp_classifier_multilabel_tanh(self):
)
self.assertTrue(model_onnx is not None)
dump_data_and_model(
X_test,
model,
model_onnx,
basename="SklearnMLPClassifierMultiLabelTanhActivation",
allow_failure="StrictVersion("
"onnxruntime.__version__)<= StrictVersion('0.2.1')",
)
X_test, model, model_onnx,
basename="SklearnMLPClassifierMultiLabelTanhActivation")

@ignore_warnings(category=(ConvergenceWarning, FutureWarning))
def test_model_mlp_regressor_tanh(self):
Expand All @@ -312,13 +289,8 @@ def test_model_mlp_regressor_tanh(self):
)
self.assertTrue(model_onnx is not None)
dump_data_and_model(
X_test,
model,
model_onnx,
basename="SklearnMLPRegressorTanhActivation-Dec4",
allow_failure="StrictVersion("
"onnxruntime.__version__)<= StrictVersion('0.2.1')",
)
X_test, model, model_onnx,
basename="SklearnMLPRegressorTanhActivation-Dec4")

@ignore_warnings(category=(ConvergenceWarning, FutureWarning))
def test_model_mlp_regressor_bool(self):
Expand All @@ -332,13 +304,8 @@ def test_model_mlp_regressor_bool(self):
)
self.assertTrue(model_onnx is not None)
dump_data_and_model(
X_test,
model,
model_onnx,
basename="SklearnMLPRegressorBool",
allow_failure="StrictVersion("
"onnxruntime.__version__)<= StrictVersion('0.2.1')",
)
X_test, model, model_onnx,
basename="SklearnMLPRegressorBool")

@unittest.skipIf(
StrictVersion(ort_version) < StrictVersion('1.0.0'),
Expand Down
4 changes: 3 additions & 1 deletion tests/test_utils/tests_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,14 @@ def fit_classification_model(model, n_classes, is_int=False,
pos_features=False, label_string=False,
random_state=42, is_bool=False,
n_features=20, n_redundant=None,
n_repeated=None):
n_repeated=None, cls_dtype=None):
X, y = make_classification(
n_classes=n_classes, n_features=n_features, n_samples=250,
random_state=random_state, n_informative=min(7, n_features),
n_redundant=n_redundant or min(2, n_features - min(7, n_features)),
n_repeated=n_repeated or 0)
if cls_dtype is not None:
y = y.astype(cls_dtype)
if label_string:
y = numpy.array(['cl%d' % cl for cl in y])
X = X.astype(numpy.int64) if is_int or is_bool else X.astype(numpy.float32)
Expand Down

0 comments on commit 544570e

Please sign in to comment.