
ONNX Flatten support #243

Merged 6 commits on May 7, 2020

Conversation

samuelemarro
Contributor

Torch converts Flatten-like operations to ONNX.Flatten, even when .reshape() is used explicitly. This pull request adds support for ONNX.Flatten, based on the existing implementation for ONNX.Reshape.

@kjulian3
Collaborator

Are you able to share an ONNX network that uses Flatten? I haven't been able to create one, because the PyTorch networks I create have their Torch flatten-like layers exported as onnx.Reshape.

Looking at the ONNX Flatten documentation, the output of ONNX Flatten should be a 2D matrix with shape (d_0 × d_1 × ... × d_{axis-1}, d_axis × d_{axis+1} × ... × d_n). The proposed changes instead seem to keep the first flattenUpTo dimensions and flatten the remaining ones. Also, the ONNX documentation says that Flatten should only have one input (the input array), plus an attribute "axis" that decides which dimension is used in the formula above to form the 2D matrix. The code as written uses a second input instead of an attribute.
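The 2D semantics described above can be sketched in a few lines of NumPy (onnx_flatten is a hypothetical helper written for illustration, not code from this PR):

```python
import numpy as np

def onnx_flatten(x, axis=1):
    # ONNX Flatten always produces a 2D array: dimensions before `axis`
    # are collapsed into the first output dimension, the rest into the
    # second. An empty slice gives np.prod(()) == 1.0, hence the int().
    d1 = int(np.prod(x.shape[:axis]))
    d2 = int(np.prod(x.shape[axis:]))
    return x.reshape(d1, d2)

x = np.zeros((2, 3, 4, 5))
print(onnx_flatten(x, axis=2).shape)  # (6, 20)
print(onnx_flatten(x, axis=0).shape)  # (1, 120)
print(onnx_flatten(x).shape)          # (2, 60), default axis=1
```

Note that even axis=0 yields a 2D result with a leading dimension of 1, never a 1D array.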

@samuelemarro
Contributor Author

This snippet converts a simple feedforward network to ONNX and prints the operation types. The network is exported as a Flatten operation followed by two Gemm operations.

import onnx
import torch
import torch.nn as nn

n_input = 28 * 28
n_hidden = 10
n_output = 10

batch_size = 10

simple_network = nn.Sequential(
    nn.Linear(n_input, n_hidden),
    nn.Linear(n_hidden, n_output)
)

class ReshapedNetwork(nn.Module):
    def __init__(self, network):
        super().__init__()
        self.network = network

    def forward(self, x):
        x = x.reshape(batch_size, -1)
        return self.network(x)

simple_network = ReshapedNetwork(simple_network)

simple_network.eval()

onnx_path = 'exported_simple_network.onnx'

dummy_input = torch.rand((batch_size, 28, 28), requires_grad=True)

# Export the model
torch.onnx.export(simple_network,
                  dummy_input,
                  onnx_path,
                  export_params=True,
                  opset_version=11,
                  do_constant_folding=True,
                  input_names=['input'],
                  output_names=['output'],
                  dynamic_axes={'input': {0: 'batch_size'},
                                'output': {0: 'batch_size'}})

onnx_model = onnx.load(onnx_path)

for node in onnx_model.graph.node:
    print(node.op_type)

Apparently it's highly sensitive to the arguments passed to .reshape(). For example, .reshape(x.shape[0], -1) is converted to Shape+Constant+Gather+Unsqueeze+Concat+Reshape, while .reshape(batch_size, -1) is converted to Flatten. nn.Flatten is converted to Flatten, as expected.

As for the documentation: it looks like I completely misinterpreted the description of the implementation. I'll fix it as soon as possible.

@samuelemarro
Contributor Author

I think I've fixed the flatten() implementation; could you take a look?

@kjulian3
Collaborator

Thanks for the Torch script, that helps a lot. Funnily enough, Torch's .reshape(batch_size, -1) is converted to onnx.Flatten, while Torch's .flatten() is converted to onnx.Reshape.

I had a look and have just a few small comments:

  1. According to the documentation, the axis can be 0 or the maximum dimension, which means lines 345/346 could compute np.prod of an empty array. That produces a floating-point 1.0, which breaks the reshape operation since it isn't an int. I'd recommend wrapping the np.prod(...) calls in int(...) to make sure that dimension1/dimension2 are always ints.
  2. For line 348, why not just do "self.shapeMap[nodeName] = newShape"?
  3. The comments on lines 330 and 334 still refer to "reshape" instead of "flatten". It would also be good to mention somewhere that Flatten reshapes the input to a 2D array. That surprised me, since I would have assumed flatten produces a 1D array like NumPy's flatten, so a comment would keep future developers from wondering whether returning a 2D array is a bug.
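To illustrate point 1, the empty-product gotcha is easy to reproduce in isolation (the variable names here are illustrative, not the ones in the PR diff):

```python
import numpy as np

shape = (2, 3, 4)
axis = 0  # edge case: no dimensions before the axis

dim1 = np.prod(shape[:axis])  # np.prod of an empty tuple is 1.0, a float
dim2 = np.prod(shape[axis:])  # 24

print(type(dim1))  # <class 'numpy.float64'>, not an int

# Wrapping in int(...) keeps downstream reshape logic happy:
dim1 = int(np.prod(shape[:axis]))  # 1
dim2 = int(np.prod(shape[axis:]))  # 24
```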

Besides these small changes it looks good! I was able to verify that Marabou produces the same output as ONNX with networks that use the ONNX.Flatten operation. Thanks for your work!

@samuelemarro
Contributor Author

Third time's the charm, hopefully. Thank you for your help!

@kjulian3
Collaborator

Looks good to me, thanks!

@kjulian3
Collaborator

I made a couple of small changes. The main one is that ONNX.Flatten uses 1 as the default axis, so I changed the code to use that default instead of raising an error when the axis attribute is not given.

@kjulian3 kjulian3 requested a review from wu-haoze April 22, 2020 20:20
@kjulian3 kjulian3 merged commit da12cb8 into NeuralNetworkVerification:master May 7, 2020
AleksandarZeljic pushed a commit to AleksandarZeljic/Marabou that referenced this pull request Oct 9, 2020
* Added Flatten support

* Fixed Flatten to reflect ONNX's documentation

* Fixed np.prod() type, added documentation.

* Default axis should be 1 if not specified

* Remove extra space

Co-authored-by: kjulian3 <kjulian3@stanford.edu>
matanost pushed a commit that referenced this pull request Nov 2, 2021