
Showing what Flux operations are used within a CompGraph #47

Closed
vtjeng opened this issue Mar 14, 2021 · 5 comments

Comments

@vtjeng

vtjeng commented Mar 14, 2021

I was hoping to use ONNXmutable.jl to extract the Flux operations specified in an .onnx file. Specifically, I'd like to be able to provide an .onnx file and load a function that corresponds to a composition of Flux primitives, something like:

using Flux

f = Chain(
  Dense(10, 5, σ),
  Dense(5, 2),
  softmax
)

What I've tried

  • I can successfully extract either a CompGraph or (using extract) an ONNX.ModelProto.
  • I am having trouble (and I expect other users will too) introspecting the internal structure of the CompGraph to see which Flux operations it contains.

Use Case

  • For neural network verification, I have implemented operations and layers that accept either JuMP variables or regular floats and produce the appropriate output: appropriately constrained JuMP variables if JuMP variables are passed in, and the regular floats of forward propagation otherwise.
  • I'd like to simplify importing by supporting .onnx files directly (for a subset of supported operations).
    • I plan to replace my custom layers with the corresponding Flux layers [1] and implement functionality that lets these Flux layers accept JuMP variables and produce the appropriate output.

[1] I'm using Flux because it seems to be the best-supported neural network framework in Julia, but I would be open to suggestions for other options.

Additional Notes

The layers that I'd like implemented are a subset of the operations that are supported (in fact, I'd really like to start with just a simple feedforward network with Gemm and Relu layers).

vtjeng changed the title from "Extending Flux operations within a CompGraph" to "Showing what Flux operations are used within a CompGraph" on Mar 14, 2021
@DrChainsaw
Owner

Hi and thanks for showing interest!

The CompGraph has many methods for introspection, and most of them are documented in NaiveNASflux.

Here is an example of the kind of ad-hoc summary table I usually use:

julia> using ONNXmutable, NaiveNASflux

julia> cg = CompGraph("resnet18-v1-7.onnx");

# Return vertices of cg as a topologically sorted array
julia> vs = vertices(cg);

# Workaround for an oversight in NaiveNASflux which I will correct one day; without it, layer fails when it hits a non-Flux layer such as the elementwise additions of the ResNet
julia> NaiveNASflux.layer(f) = f;

julia> [name.(vs) nin.(vs) nout.(vs) layer.(vs) map(ivs -> name.(ivs), inputs.(vs)) map(ovs -> name.(ovs), outputs.(vs))]
61×6 Matrix{Any}:
 "data"                             Any[]          3  LayerTypeWrapper(FluxConv{2}())                       Any[]                                                                   ["resnetv15_conv0_fwd"]
 "resnetv15_conv0_fwd"              [3]           64  Conv((7, 7), 3=>64)                                   ["data"]                                                                ["resnetv15_batchnorm0_fwd"]
 "resnetv15_batchnorm0_fwd"         [64]          64  BatchNorm(64, λ = relu)                               ["resnetv15_conv0_fwd"]                                                 ["resnetv15_pool0_fwd"]
 "resnetv15_pool0_fwd"              [64]          64  MaxPool((3, 3), pad = (1, 1, 1, 1), stride = (2, 2))  ["resnetv15_batchnorm0_fwd"]                                            ["resnetv15_stage1_conv0_fwd", "resnetv15_stage1__plus0"]
 "resnetv15_stage1_conv0_fwd"       [64]          64  Conv((3, 3), 64=>64)                                  ["resnetv15_pool0_fwd"]                                                 ["resnetv15_stage1_batchnorm0_fwd"]
 "resnetv15_stage1_batchnorm0_fwd"  [64]          64  BatchNorm(64, λ = relu)                               ["resnetv15_stage1_conv0_fwd"]                                          ["resnetv15_stage1_conv1_fwd"]
 "resnetv15_stage1_conv1_fwd"       [64]          64  Conv((3, 3), 64=>64)                                  ["resnetv15_stage1_batchnorm0_fwd"]                                     ["resnetv15_stage1_batchnorm1_fwd"]
 "resnetv15_stage1_batchnorm1_fwd"  [64]          64  BatchNorm(64)                                         ["resnetv15_stage1_conv1_fwd"]                                          ["resnetv15_stage1__plus0"]
 "resnetv15_stage1__plus0"          [64, 64]      64  #225                                                  ["resnetv15_pool0_fwd", "resnetv15_stage1_batchnorm1_fwd"]              ["resnetv15_stage1_activation0"]
 "resnetv15_stage1_activation0"     [64]          64  #195                                                  ["resnetv15_stage1__plus0"]                                             ["resnetv15_stage1_conv2_fwd", "resnetv15_stage1__plus1"]
 "resnetv15_stage1_conv2_fwd"       [64]          64  Conv((3, 3), 64=>64)                                  ["resnetv15_stage1_activation0"]                                        ["resnetv15_stage1_batchnorm2_fwd"]
 "resnetv15_stage1_batchnorm2_fwd"  [64]          64  BatchNorm(64, λ = relu)                               ["resnetv15_stage1_conv2_fwd"]                                          ["resnetv15_stage1_conv3_fwd"]
 "resnetv15_stage1_conv3_fwd"       [64]          64  Conv((3, 3), 64=>64)                                  ["resnetv15_stage1_batchnorm2_fwd"]                                     ["resnetv15_stage1_batchnorm3_fwd"]
 "resnetv15_stage1_batchnorm3_fwd"  [64]          64  BatchNorm(64)                                         ["resnetv15_stage1_conv3_fwd"]                                          ["resnetv15_stage1__plus1"]
 "resnetv15_stage1__plus1"          [64, 64]      64  #225                                                  ["resnetv15_stage1_activation0", "resnetv15_stage1_batchnorm3_fwd"]     ["resnetv15_stage1_activation1"]
 "resnetv15_stage1_activation1"     [64]          64  #195                                                  ["resnetv15_stage1__plus1"]                                             ["resnetv15_stage2_conv2_fwd", "resnetv15_stage2_conv0_fwd"]
 "resnetv15_stage2_conv2_fwd"       [64]         128  Conv((1, 1), 64=>128)                                 ["resnetv15_stage1_activation1"]                                        ["resnetv15_stage2_batchnorm2_fwd"]
 "resnetv15_stage2_batchnorm2_fwd"  [128]        128  BatchNorm(128)                                        ["resnetv15_stage2_conv2_fwd"]                                          ["resnetv15_stage2__plus0"]
 "resnetv15_stage2_conv0_fwd"       [64]         128  Conv((3, 3), 64=>128)                                 ["resnetv15_stage1_activation1"]                                        ["resnetv15_stage2_batchnorm0_fwd"]
 "resnetv15_stage2_batchnorm0_fwd"  [128]        128  BatchNorm(128, λ = relu)                              ["resnetv15_stage2_conv0_fwd"]                                          ["resnetv15_stage2_conv1_fwd"]
 "resnetv15_stage2_conv1_fwd"       [128]        128  Conv((3, 3), 128=>128)                                ["resnetv15_stage2_batchnorm0_fwd"]                                     ["resnetv15_stage2_batchnorm1_fwd"]
 "resnetv15_stage2_batchnorm1_fwd"  [128]        128  BatchNorm(128)                                        ["resnetv15_stage2_conv1_fwd"]                                          ["resnetv15_stage2__plus0"]
 "resnetv15_stage2__plus0"          [128, 128]   128  #225                                                  ["resnetv15_stage2_batchnorm2_fwd", "resnetv15_stage2_batchnorm1_fwd"]  ["resnetv15_stage2_activation0"]
 "resnetv15_stage2_activation0"     [128]        128  #195                                                  ["resnetv15_stage2__plus0"]                                             ["resnetv15_stage2_conv3_fwd", "resnetv15_stage2__plus1"]
 "resnetv15_stage2_conv3_fwd"       [128]        128  Conv((3, 3), 128=>128)                                ["resnetv15_stage2_activation0"]                                        ["resnetv15_stage2_batchnorm3_fwd"]
 "resnetv15_stage2_batchnorm3_fwd"  [128]        128  BatchNorm(128, λ = relu)                              ["resnetv15_stage2_conv3_fwd"]                                          ["resnetv15_stage2_conv4_fwd"]
 "resnetv15_stage2_conv4_fwd"       [128]        128  Conv((3, 3), 128=>128)                                ["resnetv15_stage2_batchnorm3_fwd"]                                     ["resnetv15_stage2_batchnorm4_fwd"]
 "resnetv15_stage2_batchnorm4_fwd"  [128]        128  BatchNorm(128)                                        ["resnetv15_stage2_conv4_fwd"]                                          ["resnetv15_stage2__plus1"]
 "resnetv15_stage2__plus1"          [128, 128]   128  #225                                                  ["resnetv15_stage2_activation0", "resnetv15_stage2_batchnorm4_fwd"]     ["resnetv15_stage2_activation1"]
 "resnetv15_stage2_activation1"     [128]        128  #195                                                  ["resnetv15_stage2__plus1"]                                             ["resnetv15_stage3_conv2_fwd", "resnetv15_stage3_conv0_fwd"]
 "resnetv15_stage3_conv2_fwd"       [128]        256  Conv((1, 1), 128=>256)                                ["resnetv15_stage2_activation1"]                                        ["resnetv15_stage3_batchnorm2_fwd"]
 "resnetv15_stage3_batchnorm2_fwd"  [256]        256  BatchNorm(256)                                        ["resnetv15_stage3_conv2_fwd"]                                          ["resnetv15_stage3__plus0"]
 "resnetv15_stage3_conv0_fwd"       [128]        256  Conv((3, 3), 128=>256)                                ["resnetv15_stage2_activation1"]                                        ["resnetv15_stage3_batchnorm0_fwd"]
 "resnetv15_stage3_batchnorm0_fwd"  [256]        256  BatchNorm(256, λ = relu)                              ["resnetv15_stage3_conv0_fwd"]                                          ["resnetv15_stage3_conv1_fwd"]
 "resnetv15_stage3_conv1_fwd"       [256]        256  Conv((3, 3), 256=>256)                                ["resnetv15_stage3_batchnorm0_fwd"]                                     ["resnetv15_stage3_batchnorm1_fwd"]
 "resnetv15_stage3_batchnorm1_fwd"  [256]        256  BatchNorm(256)                                        ["resnetv15_stage3_conv1_fwd"]                                          ["resnetv15_stage3__plus0"]
 "resnetv15_stage3__plus0"          [256, 256]   256  #225                                                  ["resnetv15_stage3_batchnorm2_fwd", "resnetv15_stage3_batchnorm1_fwd"]  ["resnetv15_stage3_activation0"]
 "resnetv15_stage3_activation0"     [256]        256  #195                                                  ["resnetv15_stage3__plus0"]                                             ["resnetv15_stage3_conv3_fwd", "resnetv15_stage3__plus1"]
 "resnetv15_stage3_conv3_fwd"       [256]        256  Conv((3, 3), 256=>256)                                ["resnetv15_stage3_activation0"]                                        ["resnetv15_stage3_batchnorm3_fwd"]
 "resnetv15_stage3_batchnorm3_fwd"  [256]        256  BatchNorm(256, λ = relu)                              ["resnetv15_stage3_conv3_fwd"]                                          ["resnetv15_stage3_conv4_fwd"]
 "resnetv15_stage3_conv4_fwd"       [256]        256  Conv((3, 3), 256=>256)                                ["resnetv15_stage3_batchnorm3_fwd"]                                     ["resnetv15_stage3_batchnorm4_fwd"]
 "resnetv15_stage3_batchnorm4_fwd"  [256]        256  BatchNorm(256)                                        ["resnetv15_stage3_conv4_fwd"]                                          ["resnetv15_stage3__plus1"]
 "resnetv15_stage3__plus1"          [256, 256]   256  #225                                                  ["resnetv15_stage3_activation0", "resnetv15_stage3_batchnorm4_fwd"]     ["resnetv15_stage3_activation1"]
 "resnetv15_stage3_activation1"     [256]        256  #195                                                  ["resnetv15_stage3__plus1"]                                             ["resnetv15_stage4_conv2_fwd", "resnetv15_stage4_conv0_fwd"]
 "resnetv15_stage4_conv2_fwd"       [256]        512  Conv((1, 1), 256=>512)                                ["resnetv15_stage3_activation1"]                                        ["resnetv15_stage4_batchnorm2_fwd"]
 "resnetv15_stage4_batchnorm2_fwd"  [512]        512  BatchNorm(512)                                        ["resnetv15_stage4_conv2_fwd"]                                          ["resnetv15_stage4__plus0"]
 "resnetv15_stage4_conv0_fwd"       [256]        512  Conv((3, 3), 256=>512)                                ["resnetv15_stage3_activation1"]                                        ["resnetv15_stage4_batchnorm0_fwd"]
 "resnetv15_stage4_batchnorm0_fwd"  [512]        512  BatchNorm(512, λ = relu)                              ["resnetv15_stage4_conv0_fwd"]                                          ["resnetv15_stage4_conv1_fwd"]
 "resnetv15_stage4_conv1_fwd"       [512]        512  Conv((3, 3), 512=>512)                                ["resnetv15_stage4_batchnorm0_fwd"]                                     ["resnetv15_stage4_batchnorm1_fwd"]
 "resnetv15_stage4_batchnorm1_fwd"  [512]        512  BatchNorm(512)                                        ["resnetv15_stage4_conv1_fwd"]                                          ["resnetv15_stage4__plus0"]
 "resnetv15_stage4__plus0"          [512, 512]   512  #225                                                  ["resnetv15_stage4_batchnorm2_fwd", "resnetv15_stage4_batchnorm1_fwd"]  ["resnetv15_stage4_activation0"]
 "resnetv15_stage4_activation0"     [512]        512  #195                                                  ["resnetv15_stage4__plus0"]                                             ["resnetv15_stage4_conv3_fwd", "resnetv15_stage4__plus1"]
 "resnetv15_stage4_conv3_fwd"       [512]        512  Conv((3, 3), 512=>512)                                ["resnetv15_stage4_activation0"]                                        ["resnetv15_stage4_batchnorm3_fwd"]
 "resnetv15_stage4_batchnorm3_fwd"  [512]        512  BatchNorm(512, λ = relu)                              ["resnetv15_stage4_conv3_fwd"]                                          ["resnetv15_stage4_conv4_fwd"]
 "resnetv15_stage4_conv4_fwd"       [512]        512  Conv((3, 3), 512=>512)                                ["resnetv15_stage4_batchnorm3_fwd"]                                     ["resnetv15_stage4_batchnorm4_fwd"]
 "resnetv15_stage4_batchnorm4_fwd"  [512]        512  BatchNorm(512)                                        ["resnetv15_stage4_conv4_fwd"]                                          ["resnetv15_stage4__plus1"]
 "resnetv15_stage4__plus1"          [512, 512]   512  #225                                                  ["resnetv15_stage4_activation0", "resnetv15_stage4_batchnorm4_fwd"]     ["resnetv15_stage4_activation1"]
 "resnetv15_stage4_activation1"     [512]        512  #195                                                  ["resnetv15_stage4__plus1"]                                             ["resnetv15_pool1_fwd"]
 "resnetv15_pool1_fwd"              [512]        512  #127                                                  ["resnetv15_stage4_activation1"]                                        ["flatten_170"]
 "flatten_170"                      [512]        512  Flatten(-1)                                           ["resnetv15_pool1_fwd"]                                                 ["resnetv15_dense0_fwd"]
 "resnetv15_dense0_fwd"             [512]       1000  Dense(512, 1000)                                      ["flatten_170"]                                                         Any[]

All the methods above should have appropriate docstrings.
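If all you need is a quick overview of which operation types a graph contains, a small helper along these lines could tally them. This is a minimal sketch: layertypecounts is a hypothetical helper, not part of NaiveNASflux, and using it on a CompGraph assumes the NaiveNASflux.layer workaround above has been applied so that non-Flux vertices return their function instead of failing.

```julia
# Hypothetical helper (not part of NaiveNASflux): count how many times each
# type name occurs in an iterable of layers/functions.
function layertypecounts(layers)
    counts = Dict{String,Int}()
    for l in layers
        # nameof(typeof(l)) drops type parameters, so e.g. every Conv{...}
        # is counted under the single key "Conv"
        key = string(nameof(typeof(l)))
        counts[key] = get(counts, key, 0) + 1
    end
    return counts
end

# Usage with a CompGraph (assumes the layer workaround above):
# layertypecounts(layer.(vertices(cg)))
```

Grouping by the unparameterized type name keeps the output short; anonymous functions such as the element-wise additions show up under their generated names (like the #225 entries in the table).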

I have thought about how to make the CompGraph print nicely, but I can't think of any good way. If you have any ideas I'm all ears.

For examining the structure, I usually just use Netron on an exported model. It is also possible to extract the CompGraph as a LightGraph and plot it with GraphPlot, but the results very seldom look good, so I haven't bothered to advertise this possibility.

Hope this helps!

> The layers that I'd like implemented are a subset of the operations that are supported (in fact, I'd really like to start with just a simple feedforward network with Gemm and Relu layers).

Send me a list of them and I'll add them when I find the time, or, even better, try to implement them yourself using the documentation (filing issues about whatever does not make sense in it) and send me a PR :)

@vtjeng
Author

vtjeng commented Mar 15, 2021

> I have thought about how to make the CompGraph print nicely, but I can't think of any good way. If you have any ideas I'm all ears.

You could go for something like one of the two top answers here: https://stackoverflow.com/questions/42480111/model-summary-in-pytorch. That approach doesn't handle networks with skip connections, but it does give a sense of the types of layers the model contains. I think you actually have something quite close to this in your ad-hoc print?

> For examining the structure, I usually just use Netron on an exported model.

In my case I'm only importing models, so I want to visualize the model in Julia to make sure ONNXmutable 'got things right'. vertices will do the trick; I was just struggling to find this function. (I actually still can't figure out where it is defined in your source code in NaiveNASflux / NaiveNASlib!)

> Send me a list of them and I'll add them when I find the time, or, even better, try to implement them yourself using the documentation (filing issues about whatever does not make sense in it) and send me a PR :)

I think you've actually implemented everything I need, but I'll make an issue with any additional operations I need.

@DrChainsaw
Owner

> I think you actually have something quite close to this in your ad-hoc print?

Yes, I was thinking of making something like that the show or summary method (or even a separate function), but I constantly change what I want to see, so I felt it would not be immediately useful to me. I guess it could be useful for people who did not write the library, though...

> vertices will do the trick; I was just struggling to find this function.

Ha, I realize now that it is only mentioned in passing in the docs. I will advertise it much earlier, as it is a pretty important function. Thinking back, this is because I initially thought I would build and advertise the CompGraph as a LightGraph, and vertices is a LightGraphs function.

> I actually still can't figure out where it is defined in your source code in NaiveNASflux / NaiveNASlib!

Here you go :)

The definition of flatten is right above.

> I think you've actually implemented everything I need, but I'll make an issue with any additional operations I need.

Awesome!

@DrChainsaw
Owner

> make sure ONNXmutable 'got things right'.

Forgot to comment on this. FWIW, the unit tests use a tool that compares the output of a CompGraph with the output from onnxruntime. With a little effort, one can use it as a final check that it really was the same model that materialized (assuming onnxruntime got it right, of course :) ).

It's not much fun when they don't produce the same output, but if they don't, then that's an issue here, and it becomes my problem :)

@DrChainsaw
Owner

Better late than never I guess.

NaiveNASlib 2.1.0 uses PrettyTables to show a summary like the table above by default.

[Screenshot: default show output]

The show method just calls the more customizable function graphsummary:

[Screenshot: output from graphsummary]


2 participants