diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py index c28b2459..8483cb70 100644 --- a/tests/utils/test_utils.py +++ b/tests/utils/test_utils.py @@ -62,7 +62,7 @@ class TestWrapper(unittest.TestCase): def test_keras_with_tf2onnx(self): try: import keras2onnx - except ImportError: + except (ImportError, AssertionError): warnings.warn("keras2onnx or one of its dependencies is missing.") return from keras2onnx.proto import keras