-
Notifications
You must be signed in to change notification settings - Fork 93
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ONNX Flatten support #243
ONNX Flatten support #243
Conversation
Are you able to share an onnx network that uses Flatten? I haven't been able to create one because the pytorch networks I create have the torch flatten-like layers exported to onnx.Reshape. Looking at the ONNX Flatten documentation, the output of onnx Flatten should be a 2D matrix with shape (d_0 X d_1 X ... d_{axis-1}, d_{axis} X d_{axis+1} X ...d_n). The proposed changes seem to keep the first flattenUpTo dimensions and flatten the remaining. Also, the ONNX documentation says that Flatten should only have one input (the input array), and an attribute "axis" that decides which dimension should be used in the above formula to create a 2D matrix. The code as written uses a second input instead of an attribute. |
This snippet converts a simple FeedForward network and prints the operation types. The network is converted to a Flatten operation and two GEMM operations.
Apparently it's highly sensitive to the parameters used for For what concerns the documentation, it looks like I completely misinterpreted the description of the implementation. I'll fix it asap. |
I think I've fixed the implementation for |
Thanks for the torch script, that helps a lot. Funny enough, torch's .reshape(batch_size, -1) is converted to onnx.Flatten, but torch's .flatten() is converted to onnx.Reshape. I had a look and have just a few small comments:
Besides these small changes it looks good! I was able to verify that Marabou produces the same output as ONNX with networks that use the ONNX.Flatten operation. Thanks for your work! |
Third time's the charm, hopefully. Thank you for your help! |
Looks good to me, thanks! |
I made a couple of small changes. The main change is that ONNX.Flatten uses 1 as the default axis, so I changed the code to use the default instead of giving an error if the axis attribute is not given. |
* Added Flatten support * Fixed Flatten to reflect ONNX's documentation * Fixed np.prod() type, added documentation. * Default axis should be 1 if not specified * Remove extra space Co-authored-by: kjulian3 <kjulian3@stanford.edu>
* Added Flatten support * Fixed Flatten to reflect ONNX's documentation * Fixed np.prod() type, added documentation. * Default axis should be 1 if not specified * Remove extra space Co-authored-by: kjulian3 <kjulian3@stanford.edu>
Torch converts Flatten-like operations to ONNX.Flatten, even if
.reshape()
is explicitly used. This pull request should add support for ONNX.Flatten. It is based on the implementation for ONNX.Reshape.