Improve stability of unit tests #760

Merged — 1 commit, merged on Oct 2, 2021
tests/test_algebra_to_onnx.py — 217 changes: 114 additions, 103 deletions
@@ -3,15 +3,15 @@
import numpy as np
import onnx
from onnx.defs import onnx_opset_version
import onnxruntime
from onnxruntime import InferenceSession
from onnxruntime import InferenceSession, __version__ as ort_version
try:
from onnxruntime.capi.onnxruntime_pybind11_state import (
InvalidGraph, Fail, InvalidArgument)
InvalidGraph, Fail, InvalidArgument, NotImplemented)
except ImportError:
InvalidGraph = RuntimeError
InvalidArgument = RuntimeError
Fail = RuntimeError
NotImplemented = RuntimeError
try:
# scikit-learn >= 0.22
from sklearn.utils._testing import ignore_warnings
@@ -27,6 +27,9 @@
from test_utils import TARGET_OPSET


ort_version = ort_version.split('+')[0]
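The suffix strip above matters because onnxruntime builds can report a local version label such as "1.9.0+cpu", which distutils' StrictVersion cannot parse; normalizing once at import time keeps the StrictVersion comparisons in the skip decorators below from raising. A minimal, self-contained sketch — the version string is made up for illustration:

from distutils.version import StrictVersion

raw = "1.9.0+cpu"            # hypothetical local build label reported by __version__
clean = raw.split('+')[0]    # -> "1.9.0"; StrictVersion(raw) would raise ValueError
assert StrictVersion(clean) >= StrictVersion("1.4.0")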


class TestOnnxOperatorsToOnnx(unittest.TestCase):

@unittest.skipIf(StrictVersion(onnx.__version__) < StrictVersion("1.4.0"),
@@ -46,61 +49,63 @@ def generate_onnx_graph(opv):
target_opset=opv)
return onx, (node, out, last)

for opv in ({'': 10}, 9, 10, 11, 12, TARGET_OPSET):
if isinstance(opv, dict):
if opv[''] > get_latest_tested_opset_version():
continue
elif opv is not None and opv > get_latest_tested_opset_version():
continue
for i, nbnode in enumerate((1, 2, 3, 100)):
onx, nodes = generate_onnx_graph(opv=opv)
if opv == {'': 10}:
for im in onx.opset_import:
if im.version > 10:
raise AssertionError(
"Wrong final opset\nopv={}\n{}".format(
opv, onx))
else:
for im in onx.opset_import:
if im.version > opv:
raise AssertionError(
"Wrong final opset\nopv={}\n{}".format(
opv, onx))
as_string = onx.SerializeToString()
try:
ort = InferenceSession(as_string)
except (InvalidGraph, InvalidArgument) as e:
if (isinstance(opv, dict) and
opv[''] >= onnx_opset_version()):
continue
if (isinstance(opv, int) and
opv >= onnx_opset_version()):
for opv in [{'': 10}] + list(range(9, TARGET_OPSET + 1)):
with self.subTest(opv=opv):
if isinstance(opv, dict):
if opv[''] > get_latest_tested_opset_version():
continue
raise AssertionError(
"Unable to load opv={}\n---\n{}\n---".format(
opv, onx)) from e
X = (np.ones((1, 5)) * nbnode).astype(np.float32)
res_out = ort.run(None, {'X1': X})
assert len(res_out) == 1
res = res_out[0]
self.assertEqual(res.shape, (1, 1))
inputs = None
expected = [[('Ad_C0', FloatTensorType(shape=[]))],
[('Li_Y0', FloatTensorType(shape=[]))],
[('Y', FloatTensorType(shape=[]))]]
for i, node in enumerate(nodes):
shape = node.get_output_type_inference(inputs)
self.assertEqual(len(shape), 1)
if isinstance(shape[0], tuple):
self.assertEqual(str(expected[i]), str(shape))
elif (opv is not None and
opv > get_latest_tested_opset_version()):
continue
for i, nbnode in enumerate((1, 2, 3, 100)):
onx, nodes = generate_onnx_graph(opv=opv)
if opv == {'': 10}:
for im in onx.opset_import:
if im.version > 10:
raise AssertionError(
"Wrong final opset\nopv={}\n{}".format(
opv, onx))
else:
self.assertEqual(
str(expected[i]),
str([(shape[0].onnx_name, shape[0].type)]))
inputs = shape
for im in onx.opset_import:
if im.version > opv:
raise AssertionError(
"Wrong final opset\nopv={}\n{}".format(
opv, onx))
as_string = onx.SerializeToString()
try:
ort = InferenceSession(as_string)
except (InvalidGraph, InvalidArgument) as e:
if (isinstance(opv, dict) and
opv[''] >= onnx_opset_version()):
continue
if (isinstance(opv, int) and
opv >= onnx_opset_version()):
continue
raise AssertionError(
"Unable to load opv={}\n---\n{}\n---".format(
opv, onx)) from e
X = (np.ones((1, 5)) * nbnode).astype(np.float32)
res_out = ort.run(None, {'X1': X})
assert len(res_out) == 1
res = res_out[0]
self.assertEqual(res.shape, (1, 1))
inputs = None
expected = [[('Ad_C0', FloatTensorType(shape=[]))],
[('Li_Y0', FloatTensorType(shape=[]))],
[('Y', FloatTensorType(shape=[]))]]
for i, node in enumerate(nodes):
shape = node.get_output_type_inference(inputs)
self.assertEqual(len(shape), 1)
if isinstance(shape[0], tuple):
self.assertEqual(str(expected[i]), str(shape))
else:
self.assertEqual(
str(expected[i]),
str([(shape[0].onnx_name, shape[0].type)]))
inputs = shape
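The wrapping in self.subTest(opv=opv) is the core stability change in this hunk: when one opset value fails, unittest records that failure together with the opv it belongs to and keeps iterating, instead of aborting the whole test method on the first bad opset; the candidate list is also built from range(9, TARGET_OPSET + 1) rather than a hard-coded tuple. A minimal, self-contained sketch of the pattern (not the project's code):

import unittest

class SubTestDemo(unittest.TestCase):
    def test_many_opsets(self):
        for opv in range(9, 13):
            with self.subTest(opv=opv):
                # A failure for one opv is reported with its parameters,
                # and the loop still runs for the remaining values.
                self.assertGreaterEqual(opv, 9)

if __name__ == "__main__":
    unittest.main()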

def common_test_sub_graph(self, first_input, model, options=None,
cls_type=FloatTensorType):
cls_type=FloatTensorType, start=9):
def generate_onnx_graph(opv):
dtype = np.float32 if cls_type == FloatTensorType else np.float64
node = OnnxAdd(first_input, np.array([0.1], dtype=dtype),
@@ -119,48 +124,52 @@ def generate_onnx_graph(opv):

dtype = np.float32 if cls_type == FloatTensorType else np.float64

for opv in ({'': 10}, 9, 10, 11, 12, TARGET_OPSET):
if isinstance(opv, dict):
if opv[''] > get_latest_tested_opset_version():
continue
elif opv is not None and opv > get_latest_tested_opset_version():
continue
for i, nbnode in enumerate((1, 2, 3, 100)):
onx = generate_onnx_graph(opv=opv)
if opv == {'': 10}:
for im in onx.opset_import:
if im.version > 10:
raise AssertionError(
"Wrong final opset\nopv={}\n{}".format(
opv, onx))
else:
for im in onx.opset_import:
if im.version > opv:
raise AssertionError(
"Wrong final opset\nopv={}\n{}".format(
opv, onx))
self.assertNotIn('zipmap', str(onx))
as_string = onx.SerializeToString()
try:
ort = InferenceSession(as_string)
except (InvalidGraph, InvalidArgument, Fail) as e:
if (isinstance(opv, dict) and
opv[''] >= onnx_opset_version()):
opsets = list(range(start, TARGET_OPSET + 1))
for opv in [{'': TARGET_OPSET}] + opsets:
with self.subTest(opv=opv):
if isinstance(opv, dict):
if opv[''] > get_latest_tested_opset_version():
continue
if (isinstance(opv, int) and
opv >= onnx_opset_version()):
continue
raise AssertionError(
"Unable to load opv={}\n---\n{}\n---".format(
opv, onx)) from e
X = (np.ones((1, 5)) * nbnode).astype(dtype)
res_out = ort.run(None, {'X1': X})
assert len(res_out) == 1
res = res_out[0]
if model == LogisticRegression:
self.assertEqual(res.shape, (1, 3))
else:
self.assertEqual(res.shape, (1, 1))
elif (opv is not None and
opv > get_latest_tested_opset_version()):
continue
for i, nbnode in enumerate((1, 2, 3, 100)):
onx = generate_onnx_graph(opv=opv)
if opv == {'': TARGET_OPSET}:
for im in onx.opset_import:
if im.version > TARGET_OPSET:
raise AssertionError(
"Wrong final opset\nopv={}\n{}".format(
opv, onx))
else:
for im in onx.opset_import:
if im.version > opv:
raise AssertionError(
"Wrong final opset\nopv={}\n{}".format(
opv, onx))
self.assertNotIn('zipmap', str(onx).lower())
as_string = onx.SerializeToString()
try:
ort = InferenceSession(as_string)
except (InvalidGraph, InvalidArgument, Fail,
NotImplemented) as e:
if (isinstance(opv, dict) and
opv[''] >= onnx_opset_version()):
continue
if (isinstance(opv, int) and
opv >= onnx_opset_version()):
continue
raise AssertionError(
"Unable to load opv={}\n---\n{}\n---".format(
opv, onx)) from e
X = (np.ones((1, 5)) * nbnode).astype(dtype)
res_out = ort.run(None, {'X1': X})
assert len(res_out) == 1
res = res_out[0]
if model == LogisticRegression:
self.assertEqual(res.shape, (1, 3))
else:
self.assertEqual(res.shape, (1, 1))
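common_test_sub_graph now takes a start opset and derives its candidates from a range instead of a fixed tuple; the float64 variants further down pass start=13, presumably because older opsets lack double kernels for some of the operators involved (the nearby skips mention ArgMax). A small sketch of how the candidates are assembled — TARGET_OPSET = 15 is assumed here purely for illustration, the real value comes from test_utils:

TARGET_OPSET = 15                               # assumed for the example only
start = 13                                      # what the float64 test variants pass
opsets = list(range(start, TARGET_OPSET + 1))   # [13, 14, 15]
candidates = [{'': TARGET_OPSET}] + opsets      # [{'': 15}, 13, 14, 15]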

@unittest.skipIf(StrictVersion(onnx.__version__) < StrictVersion("1.4.0"),
reason="not available")
@@ -172,7 +181,7 @@ def test_sub_graph_tuple(self):
@unittest.skipIf(StrictVersion(onnx.__version__) < StrictVersion("1.4.0"),
reason="not available")
@unittest.skipIf(
StrictVersion(onnxruntime.__version__) < StrictVersion("1.4.0"),
StrictVersion(ort_version) < StrictVersion("1.4.0"),
reason="not available")
@ignore_warnings(category=DeprecationWarning)
def test_sub_graph_tuple_double(self):
@@ -189,7 +198,7 @@ def test_sub_graph_str(self):
@unittest.skipIf(StrictVersion(onnx.__version__) < StrictVersion("1.4.0"),
reason="not available")
@unittest.skipIf(
StrictVersion(onnxruntime.__version__) < StrictVersion("1.4.0"),
StrictVersion(ort_version) < StrictVersion("1.4.0"),
reason="not available")
@ignore_warnings(category=DeprecationWarning)
def test_sub_graph_str_double(self):
@@ -207,36 +216,38 @@ def test_sub_graph_tuple_cls(self):
@unittest.skipIf(StrictVersion(onnx.__version__) < StrictVersion("1.4.0"),
reason="not available")
@unittest.skipIf(
StrictVersion(onnxruntime.__version__) < StrictVersion("1.4.0"),
StrictVersion(ort_version) < StrictVersion("1.4.0"),
reason="not available")
@unittest.skipIf(
StrictVersion(onnxruntime.__version__) < StrictVersion("1.10.0"),
StrictVersion(ort_version) < StrictVersion("1.10.0"),
reason="ArgMax not available for double")
@ignore_warnings(category=DeprecationWarning)
def test_sub_graph_tuple_cls_double(self):
self.common_test_sub_graph(
('X1', DoubleTensorType()), LogisticRegression,
options={'zipmap': False}, cls_type=DoubleTensorType)
options={'zipmap': False}, cls_type=DoubleTensorType,
start=13)

@unittest.skipIf(StrictVersion(onnx.__version__) < StrictVersion("1.4.0"),
reason="not available")
@ignore_warnings(category=DeprecationWarning)
def test_sub_graph_str_cls(self):
self.common_test_sub_graph('X1', LogisticRegression,
{'zipmap': False})

@unittest.skipIf(StrictVersion(onnx.__version__) < StrictVersion("1.4.0"),
reason="not available")
@unittest.skipIf(
StrictVersion(onnxruntime.__version__) < StrictVersion("1.4.0"),
StrictVersion(ort_version) < StrictVersion("1.4.0"),
reason="not available")
@unittest.skipIf(
StrictVersion(onnxruntime.__version__) < StrictVersion("1.10.0"),
StrictVersion(ort_version) < StrictVersion("1.10.0"),
reason="ArgMax not available for double")
@ignore_warnings(category=DeprecationWarning)
def test_sub_graph_str_cls_double(self):
self.common_test_sub_graph(
'X1', LogisticRegression, options={'zipmap': False},
cls_type=DoubleTensorType)
cls_type=DoubleTensorType, start=13)


if __name__ == "__main__":