Fix type issue when using FeatureVectorizer #959

Merged · 7 commits · Jan 18, 2023
13 changes: 10 additions & 3 deletions skl2onnx/common/_apply_operation.py
@@ -1,7 +1,14 @@
 # SPDX-License-Identifier: Apache-2.0

 import numpy as np
-from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE
+try:
+    from onnx.helper import np_dtype_to_tensor_dtype
+except ImportError:
+    from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE
+
+    def np_dtype_to_tensor_dtype(dtype):
+        return NP_TYPE_TO_TENSOR_TYPE[dtype]

 from onnxconverter_common.onnx_ops import *  # noqa
 from ..proto import onnx_proto
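
For context, the shim's behavior on both old and new onnx releases can be sketched standalone (the assertions assume the standard numpy-to-TensorProto mapping):

import numpy as np
from onnx import TensorProto

try:
    # onnx >= 1.13 ships the helper directly.
    from onnx.helper import np_dtype_to_tensor_dtype
except ImportError:
    # Older onnx only exposes the raw mapping dictionary.
    from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE

    def np_dtype_to_tensor_dtype(dtype):
        return NP_TYPE_TO_TENSOR_TYPE[dtype]

assert np_dtype_to_tensor_dtype(np.dtype("float32")) == TensorProto.FLOAT
assert np_dtype_to_tensor_dtype(np.dtype("float64")) == TensorProto.DOUBLE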

@@ -129,7 +136,7 @@ def apply_clip(scope, input_name, output_name, container,
         else:
             min = np.array(min)
             container.add_initializer(
-                min_name, NP_TYPE_TO_TENSOR_TYPE[min.dtype],
+                min_name, np_dtype_to_tensor_dtype(min.dtype),
                 [], [min[0]])
             min = min_name
         if isinstance(min, str):
@@ -169,7 +176,7 @@ def apply_clip(scope, input_name, output_name, container,
         else:
             max = np.array(max)
             container.add_initializer(
-                max_name, NP_TYPE_TO_TENSOR_TYPE[max.dtype],
+                max_name, np_dtype_to_tensor_dtype(max.dtype),
                 [], [max[0]])
             max = max_name
         if isinstance(max, str):
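Since opset 11, Clip takes its min/max bounds as optional inputs that must share the input tensor's element type, which is why the converter now looks the tensor type up from the bound's numpy dtype. A minimal, hand-built sketch (illustrative names, not the converter's actual output):

import numpy as np
from onnx import TensorProto, checker, helper
from onnx.helper import np_dtype_to_tensor_dtype  # onnx >= 1.13

# For a double input "x", the "min" bound must also be DOUBLE;
# np_dtype_to_tensor_dtype(np.dtype("float64")) returns TensorProto.DOUBLE.
min_init = helper.make_tensor(
    "min", np_dtype_to_tensor_dtype(np.dtype("float64")), [], [0.0])
node = helper.make_node("Clip", ["x", "min"], ["y"])
graph = helper.make_graph(
    [node], "clip_sketch",
    [helper.make_tensor_value_info("x", TensorProto.DOUBLE, [None])],
    [helper.make_tensor_value_info("y", TensorProto.DOUBLE, [None])],
    initializer=[min_init])
model = helper.make_model(
    graph, opset_imports=[helper.make_opsetid("", 13)])
checker.check_model(model)  # structural sanity check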
16 changes: 11 additions & 5 deletions skl2onnx/operator_converters/common.py
@@ -48,13 +48,19 @@ def concatenate_variables(scope, variables, container, main_type=None):

     # To combine all inputs, we need a FeatureVectorizer
     op_type = 'FeatureVectorizer'
-    attrs = {
-        'name': scope.get_unique_operator_name(op_type),
-        'inputdimensions': input_dims}
+    attrs = {'name': scope.get_unique_operator_name(op_type),
+             'inputdimensions': input_dims}
     # Create a variable name to capture feature vectorizer's output
-    # Set up our FeatureVectorizer
     concatenated_name = scope.get_unique_variable_name('concatenated')
     container.add_node(op_type, input_names, concatenated_name,
                        op_domain='ai.onnx.ml', **attrs)
+    if main_type == FloatTensorType:
+        return concatenated_name
+    # Cast output as FeatureVectorizer always produces float32.
+    concatenated_name_cast = scope.get_unique_variable_name(
+        'concatenated_cast')
+    container.add_node('CastLike', [concatenated_name, input_names[0]],
+                       concatenated_name_cast)

-    return concatenated_name
+    return concatenated_name_cast
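
The second add_node is the heart of the fix: FeatureVectorizer is specified to emit float32 regardless of its input type, so non-float graphs need a CastLike back to the original type. A hand-built sketch of the emitted pattern (hypothetical tensor names; assumes opset 15 and an onnxruntime recent enough for CastLike):

import numpy as np
from onnx import TensorProto, helper
from onnxruntime import InferenceSession

# FeatureVectorizer concatenates its inputs but always outputs float32;
# CastLike restores the dtype of the first input (here: double).
nodes = [
    helper.make_node("FeatureVectorizer", ["a", "b"], ["concat"],
                     domain="ai.onnx.ml", inputdimensions=[1, 1]),
    helper.make_node("CastLike", ["concat", "a"], ["out"]),
]
graph = helper.make_graph(
    nodes, "fv_cast",
    [helper.make_tensor_value_info("a", TensorProto.DOUBLE, [None, 1]),
     helper.make_tensor_value_info("b", TensorProto.DOUBLE, [None, 1])],
    [helper.make_tensor_value_info("out", TensorProto.DOUBLE, [None, 2])])
model = helper.make_model(graph, opset_imports=[
    helper.make_opsetid("", 15), helper.make_opsetid("ai.onnx.ml", 1)])

sess = InferenceSession(model.SerializeToString(),
                        providers=["CPUExecutionProvider"])
a = np.array([[1.5]], dtype=np.float64)
b = np.array([[2.5]], dtype=np.float64)
print(sess.run(None, {"a": a, "b": b})[0])  # [[1.5 2.5]] as float64

The round trip through float32 is lossy for double inputs, which is why the new test below compares predictions only to four decimals.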
40 changes: 38 additions & 2 deletions tests/test_sklearn_pipeline.py
@@ -12,6 +12,7 @@
 from sklearn import datasets
 from sklearn.calibration import CalibratedClassifierCV
 from sklearn.base import TransformerMixin, BaseEstimator
+from sklearn.linear_model import LinearRegression
 from sklearn.tree import DecisionTreeClassifier

 try:
@@ -54,7 +55,8 @@
 from sklearn.multioutput import MultiOutputClassifier
 from test_utils import (
     dump_data_and_model, fit_classification_model, TARGET_OPSET,
-    InferenceSessionEx as InferenceSession)
+    InferenceSessionEx as InferenceSession,
+    ReferenceEvaluatorEx)
 from onnxruntime import __version__ as ort_version


@@ -171,7 +173,9 @@ def test_combine_inputs_union_in_pipeline(self):
         dump_data_and_model(
             data, PipeConcatenateInput(model),
             model_onnx, basename="SklearnPipelineScaler11Union")
+        TARGET_OPSET

Check notice (Code scanning / CodeQL): Statement has no effect.

+    @unittest.skipIf(TARGET_OPSET < 15, reason="uses CastLike")
     @unittest.skipIf(
         pv.Version(ort_version) <= pv.Version('0.4.0'),
         reason="onnxruntime too old")
@@ -265,7 +269,11 @@ def test_pipeline_column_transformer(self):
basename="SklearnPipelineColumnTransformerPipeliner")

if __name__ == "__main__":
from onnx.tools.net_drawer import GetPydotGraph, GetOpNodeProducer
try:
from onnx.tools.net_drawer import (
GetPydotGraph, GetOpNodeProducer)
except ImportError:
return

pydot_graph = GetPydotGraph(
model_onnx.graph,
@@ -1035,10 +1043,38 @@ def transform(self, X):
             to_onnx(model, X_in)
         self.assertIn('ColumnTransformer', str(e))

+    @unittest.skipIf(TARGET_OPSET < 15, reason="use CastLike")
+    def test_feature_vectorizer_double(self):
+        dataset = datasets.load_diabetes(as_frame=True)
+        X, y = dataset.data, dataset.target
+        X["sexi"] = X["sex"].astype(numpy.int64)
+        X = X.drop("sex", axis=1)
+        X_train, X_test, y_train, y_test = train_test_split(X, y)
+        regr = Pipeline([("std", StandardScaler()),
+                         ("reg", LinearRegression())])
+        regr = regr.fit(X_train, y_train)
+        onnx_model = to_onnx(regr, X=X_train)
+
+        sess = InferenceSession(
+            onnx_model.SerializeToString(),
+            providers=["CPUExecutionProvider"])
+        expected = regr.predict(X_test)
+        names = [i.name for i in sess.get_inputs()]
+        feeds = {n: X_test[c].values.reshape((-1, 1))
+                 for n, c in zip(names, X_test.columns)}
+        got = sess.run(None, feeds)
+        assert_almost_equal(expected.ravel(), got[0].ravel(), decimal=4)
+        if ReferenceEvaluatorEx is None:
+            return
+        ref = ReferenceEvaluatorEx(onnx_model)
+        got = ref.run(None, feeds)
+        assert_almost_equal(expected.ravel(), got[0].ravel(), decimal=4)


 if __name__ == "__main__":
     # import logging
     # logger = logging.getLogger('skl2onnx')
     # logger.setLevel(logging.DEBUG)
     # logging.basicConfig(level=logging.DEBUG)
+    # TestSklearnPipeline().test_feature_vectorizer_double()
     unittest.main(verbosity=2)
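
Outside the test harness, the same end-to-end check can be reproduced in a few lines. Everything below is a sketch assuming a recent skl2onnx and onnxruntime (CastLike needs opset 15); the feeding pattern relies on skl2onnx creating one tensor(double) input per DataFrame column:

import numpy
from sklearn.datasets import load_diabetes
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from skl2onnx import to_onnx
from onnxruntime import InferenceSession

dataset = load_diabetes(as_frame=True)
X, y = dataset.data, dataset.target
pipe = Pipeline([("std", StandardScaler()),
                 ("reg", LinearRegression())]).fit(X, y)
onx = to_onnx(pipe, X=X)  # one tensor(double) input per column
sess = InferenceSession(onx.SerializeToString(),
                        providers=["CPUExecutionProvider"])
feeds = {i.name: X[c].values.reshape((-1, 1))
         for i, c in zip(sess.get_inputs(), X.columns)}
expected = pipe.predict(X)
got = sess.run(None, feeds)[0]
# Tolerance accounts for the float32 round trip inside FeatureVectorizer.
numpy.testing.assert_allclose(expected.ravel(), got.ravel(), rtol=1e-3)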
17 changes: 10 additions & 7 deletions tests/test_sklearn_random_forest_converters.py
@@ -290,14 +290,17 @@ def test_model_random_forest_classifier_multi_output_int(self):

     @ignore_warnings(category=FutureWarning)
     def common_test_model_hgb_regressor(self, add_nan=False):
-        model = HistGradientBoostingRegressor(max_iter=5, max_depth=2)
+        rng = numpy.random.RandomState(12345)
+        model = HistGradientBoostingRegressor(max_iter=4, max_depth=2)
         X, y = make_regression(n_features=10, n_samples=1000,
                                n_targets=1, random_state=42)
         if add_nan:
-            rows = numpy.random.randint(0, X.shape[0] - 1, X.shape[0] // 3)
-            cols = numpy.random.randint(0, X.shape[1] - 1, X.shape[0] // 3)
+            rows = rng.randint(0, X.shape[0] - 1, X.shape[0] // 3)
+            cols = rng.randint(0, X.shape[1] - 1, X.shape[0] // 3)
             X[rows, cols] = numpy.nan

+        X = X.astype(numpy.float32)
+        y = y.astype(numpy.float32)
         X_train, X_test, y_train, _ = train_test_split(X, y, test_size=0.5,
                                                        random_state=42)
         model.fit(X_train, y_train)
@@ -306,10 +309,10 @@ def common_test_model_hgb_regressor(self, add_nan=False):
model, "unused", [("input", FloatTensorType([None, X.shape[1]]))],
target_opset=TARGET_OPSET)
self.assertIsNotNone(model_onnx)
X_test = X_test.astype(numpy.float32)[:5]
X_test = X_test.astype(numpy.float32)[:10]
dump_data_and_model(
X_test, model, model_onnx,
basename="SklearnHGBRegressor", verbose=False)
basename=f"SklearnHGBRegressor{add_nan}", verbose=False)

@unittest.skipIf(_sklearn_version() < pv.Version('0.22.0'),
reason="missing_go_to_left is missing")
@@ -482,7 +485,7 @@ def test_boston_pca_rf(self):
             X, y, random_state=0)
         pipe = Pipeline([
             ('acp', PCA(n_components=3)),
-            ('rf', RandomForestRegressor())])
+            ('rf', RandomForestRegressor(n_estimators=100))])
         pipe.fit(X_train, y_train)
         X32 = X_test.astype(numpy.float32)
         model_onnx = to_onnx(pipe, X32[:1], target_opset=TARGET_OPSET)
@@ -742,4 +745,4 @@ def test_rf_classifier_decision_path_leaf(self):


 if __name__ == "__main__":
-    unittest.main()
+    unittest.main(verbosity=2)
35 changes: 19 additions & 16 deletions tests/test_utils/reference_implementation_tree.py
@@ -92,23 +92,26 @@ def leaf_index_tree(self, X, tree_id):
         index = self.root_index[tree_id]
         while self.atts.nodes_modes[index] != "LEAF":
             x = X[self.atts.nodes_featureids[index]]
-            rule = self.atts.nodes_modes[index]
-            th = self.atts.nodes_values[index]
-            if rule == "BRANCH_LEQ":
-                r = x <= th
-            elif rule == "BRANCH_LT":
-                r = x < th
-            elif rule == "BRANCH_GTE":
-                r = x >= th
-            elif rule == "BRANCH_GT":
-                r = x > th
-            elif rule == "BRANCH_EQ":
-                r = x == th
-            elif rule == "BRANCH_NEQ":
-                r = x != th
+            if np.isnan(x):
+                r = self.atts.nodes_missing_value_tracks_true[index] >= 1
             else:
-                raise ValueError(
-                    f"Unexpected rule {rule!r} for node index {index}.")
+                rule = self.atts.nodes_modes[index]
+                th = self.atts.nodes_values[index]
+                if rule == "BRANCH_LEQ":
+                    r = x <= th
+                elif rule == "BRANCH_LT":
+                    r = x < th
+                elif rule == "BRANCH_GTE":
+                    r = x >= th
+                elif rule == "BRANCH_GT":
+                    r = x > th
+                elif rule == "BRANCH_EQ":
+                    r = x == th
+                elif rule == "BRANCH_NEQ":
+                    r = x != th
+                else:
+                    raise ValueError(
+                        f"Unexpected rule {rule!r} for node index {index}.")
             nid = (self.atts.nodes_truenodeids[index]
                    if r else self.atts.nodes_falsenodeids[index])
             index = self.node_index[tree_id, nid]
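The relocated NaN check mirrors the TreeEnsemble spec: comparisons with NaN are always false, so missing values must be routed by nodes_missing_value_tracks_true before any branch rule fires. A one-rule sketch (standalone, not the class above):

import numpy as np

def branch(x, rule, threshold, missing_tracks_true):
    # NaN never satisfies a comparison, so it is routed explicitly
    # before the rule is evaluated, exactly as in the fix above.
    if np.isnan(x):
        return missing_tracks_true >= 1
    if rule == "BRANCH_LEQ":
        return x <= threshold
    raise ValueError(f"Unexpected rule {rule!r}.")

print(branch(0.2, "BRANCH_LEQ", 0.5, 0))     # True: 0.2 <= 0.5
print(branch(np.nan, "BRANCH_LEQ", 0.5, 1))  # True: NaN takes the true branch
print(branch(np.nan, "BRANCH_LEQ", 0.5, 0))  # False: NaN takes the false branch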