Skip to content

Commit

Permalink
Add unit test from issue 1129
Browse files Browse the repository at this point in the history
Signed-off-by: xadupre <xadupre@microsoft.com>
  • Loading branch information
xadupre committed Oct 2, 2024
1 parent f98c75e commit 24a2d80
Showing 1 changed file with 31 additions and 17 deletions.
48 changes: 31 additions & 17 deletions tests/test_issues_2024.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,8 @@ def test_issue_1129_lr(self):
from numpy.testing import assert_almost_equal
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
import skl2onnx
from onnxruntime import InferenceSession

Expand All @@ -286,30 +288,42 @@ def test_issue_1129_lr(self):
"float_column": np.random.rand(10).astype(np.float64),
"int_column": np.random.randint(0, 100, size=10).astype(np.int64),
}
x = pd.DataFrame(data)
x_ = pd.DataFrame(data)
y = np.random.binomial(1, 0.5, size=10)

# Create a test dataset with 10 rows
test_data = {
"float_column": np.random.rand(10).astype(np.float64),
"int_column": np.random.randint(0, 100, size=10).astype(np.int64),
}
x_test = pd.DataFrame(test_data)

# Select and train a model
model = LogisticRegression()
model.fit(x, y)
# Take predictions and probabilities with sklearn
sklearn_preds = model.predict(x_test)
sklearn_probs = model.predict_proba(x_test)

# Convert the model to ONNX
onnx_model = skl2onnx.to_onnx(model, x.values, options={"zipmap": False})
# Take predictions and probabilities with ONNX
sess = InferenceSession(onnx_model.SerializeToString())
onnx_prediction = sess.run(None, {"X": x_test.to_numpy()})
assert_almost_equal(sklearn_probs, onnx_prediction[1])
assert_almost_equal(sklearn_preds, onnx_prediction[0])
x_test_ = pd.DataFrame(test_data)

for cls in [LogisticRegression, DecisionTreeClassifier, RandomForestClassifier]:
with self.subTest(cls=cls):
# Select and train a model
if cls == LogisticRegression:
x = x_.astype(np.float64)
x_test = x_test_.astype(np.float64)
decimal = 10
else:
x = x_.astype(np.float32)
x_test = x_test_.astype(np.float32)
decimal = 4
model = cls()
model.fit(x, y)
# Take predictions and probabilities with sklearn
sklearn_preds = model.predict(x_test)
sklearn_probs = model.predict_proba(x_test)

# Convert the model to ONNX
onnx_model = skl2onnx.to_onnx(
model, x.values, options={"zipmap": False}
)
# Take predictions and probabilities with ONNX
sess = InferenceSession(onnx_model.SerializeToString())
onnx_prediction = sess.run(None, {"X": x_test.to_numpy()})
assert_almost_equal(sklearn_probs, onnx_prediction[1], decimal=decimal)
assert_almost_equal(sklearn_preds, onnx_prediction[0])


if __name__ == "__main__":
Expand Down

0 comments on commit 24a2d80

Please sign in to comment.