-
Notifications
You must be signed in to change notification settings - Fork 93
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement remaining missing operations in native ONNX parser #630
Implement remaining missing operations in native ONNX parser #630
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi Matthew, thanks for this contribution.
The implementation looks correct to me. However, given the amount of details in the conv operation, is it possible to add some unit tests to the onnx parser?
One way might be to load a small convolutional network, set the input to constant, solve it and see if the output variable assignment matches the expectation.
Otherwise, the PR looks good to me.
Here are some unit tests that we have for the onnx parser implemented in the python API. I think those networks are small enough to be considered for the unit test.
Marabou/maraboupy/test/test_onnx.py
Lines 67 to 74 in 663835d
def test_conv_mp1(): | |
""" | |
Test a convolutional network using max pool, exported from pytorch | |
Uses Conv, Relu, MaxPool, Constant, Reshape, Transpose, | |
Matmul, and Add layers | |
""" | |
filename = "conv_mp1.onnx" | |
evaluateFile(filename, inputNames = ['X'], outputNames = ['Y']) |
Marabou/maraboupy/test/test_onnx.py
Lines 50 to 57 in 663835d
def test_KJ_TinyTaxiNet(): | |
""" | |
Test a convolutional network, exported from tensorflow | |
Uses Transpose, Conv, Add, Relu, Cast, Reshape, | |
Matmul, and Identity layers | |
""" | |
filename = "KJ_TinyTaxiNet.onnx" | |
evaluateFile(filename) |
printf( "Error: no network file provided!\n" ); | ||
throw MarabouError( MarabouError::FILE_DOESNT_EXIST, networkFilePath.ascii() ); | ||
} | ||
|
||
if ( !File::exists( networkFilePath ) ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Change this line to "else if" then?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd prefer not to. I think early return points are much cleaner to read and require less mental effort. However, if you insist I can change it?
Thanks for the review!
Yup sure! I'm trying to add your example networks now, but I'm having to implement some of the missing operators such as |
Added the missing operation implementations. Will add the tests later this week. |
@anwu1219 in the end I decided it would be far better to add some proper unit tests rather than simply pushing through large networks. I've therefore added a zoo of ONNX networks each consisting of a single node (+ some constants), one for each layer type that we support. There are two layers for which I've run into problems:
Given the simplicity of the sigmoid parsing code, I don't really think it can be the fault of the ONNX parser as it's identical to that for the python ONNX parser sigmoid implementation. Given that sigmoid was already implemented already, it not working doesn't make things any worse than they previously were. I therefore propose that we file an issue for that, and this PR is ready to merge in, subject to your final approval? |
Oh I forgot to say that I updated the version of the C++ ONNX library we build against, so that we use the same version of ONNX for both the Python and the C++ backends 👍 |
Same non-deterministic test failure with the |
Hi @MatthewDaggitt , thanks a lot for the added test! The PR LGTM! |
Interesting.. Yes, please file an issue for this. Thank you! |
Also fixes a couple of bugs in the parser I found and adds a better error message for when the user forgets to provide a network file on the command line.