Showing what Flux operations are used within a CompGraph #47
Hi and thanks for showing interest! The CompGraph has many methods for introspection, and most of them are documented in NaiveNASflux. Here is an example of the type of ad-hoc summary table I usually use:

julia> cg = CompGraph("resnet18-v1-7.onnx");
# Return vertices of cg as a topologically sorted array
julia> vs = vertices(cg);
# This is due to an oversight in NaiveNASflux which I will correct one day;
# otherwise it will fail when it hits a non-Flux layer like the element-wise additions of the resnet
julia> NaiveNASflux.layer(f) = f;
julia> [name.(vs) nin.(vs) nout.(vs) layer.(vs) map(ivs -> name.(ivs), inputs.(vs)) map(ovs -> name.(ovs), outputs.(vs))]
61×6 Matrix{Any}:
"data" Any[] 3 LayerTypeWrapper(FluxConv{2}()) Any[] ["resnetv15_conv0_fwd"]
"resnetv15_conv0_fwd" [3] 64 Conv((7, 7), 3=>64) ["data"] ["resnetv15_batchnorm0_fwd"]
"resnetv15_batchnorm0_fwd" [64] 64 BatchNorm(64, λ = relu) ["resnetv15_conv0_fwd"] ["resnetv15_pool0_fwd"]
"resnetv15_pool0_fwd" [64] 64 MaxPool((3, 3), pad = (1, 1, 1, 1), stride = (2, 2)) ["resnetv15_batchnorm0_fwd"] ["resnetv15_stage1_conv0_fwd", "resnetv15_stage1__plus0"]
"resnetv15_stage1_conv0_fwd" [64] 64 Conv((3, 3), 64=>64) ["resnetv15_pool0_fwd"] ["resnetv15_stage1_batchnorm0_fwd"]
"resnetv15_stage1_batchnorm0_fwd" [64] 64 BatchNorm(64, λ = relu) ["resnetv15_stage1_conv0_fwd"] ["resnetv15_stage1_conv1_fwd"]
"resnetv15_stage1_conv1_fwd" [64] 64 Conv((3, 3), 64=>64) ["resnetv15_stage1_batchnorm0_fwd"] ["resnetv15_stage1_batchnorm1_fwd"]
"resnetv15_stage1_batchnorm1_fwd" [64] 64 BatchNorm(64) ["resnetv15_stage1_conv1_fwd"] ["resnetv15_stage1__plus0"]
"resnetv15_stage1__plus0" [64, 64] 64 #225 ["resnetv15_pool0_fwd", "resnetv15_stage1_batchnorm1_fwd"] ["resnetv15_stage1_activation0"]
"resnetv15_stage1_activation0" [64] 64 #195 ["resnetv15_stage1__plus0"] ["resnetv15_stage1_conv2_fwd", "resnetv15_stage1__plus1"]
"resnetv15_stage1_conv2_fwd" [64] 64 Conv((3, 3), 64=>64) ["resnetv15_stage1_activation0"] ["resnetv15_stage1_batchnorm2_fwd"]
"resnetv15_stage1_batchnorm2_fwd" [64] 64 BatchNorm(64, λ = relu) ["resnetv15_stage1_conv2_fwd"] ["resnetv15_stage1_conv3_fwd"]
"resnetv15_stage1_conv3_fwd" [64] 64 Conv((3, 3), 64=>64) ["resnetv15_stage1_batchnorm2_fwd"] ["resnetv15_stage1_batchnorm3_fwd"]
"resnetv15_stage1_batchnorm3_fwd" [64] 64 BatchNorm(64) ["resnetv15_stage1_conv3_fwd"] ["resnetv15_stage1__plus1"]
"resnetv15_stage1__plus1" [64, 64] 64 #225 ["resnetv15_stage1_activation0", "resnetv15_stage1_batchnorm3_fwd"] ["resnetv15_stage1_activation1"]
"resnetv15_stage1_activation1" [64] 64 #195 ["resnetv15_stage1__plus1"] ["resnetv15_stage2_conv2_fwd", "resnetv15_stage2_conv0_fwd"]
"resnetv15_stage2_conv2_fwd" [64] 128 Conv((1, 1), 64=>128) ["resnetv15_stage1_activation1"] ["resnetv15_stage2_batchnorm2_fwd"]
"resnetv15_stage2_batchnorm2_fwd" [128] 128 BatchNorm(128) ["resnetv15_stage2_conv2_fwd"] ["resnetv15_stage2__plus0"]
"resnetv15_stage2_conv0_fwd" [64] 128 Conv((3, 3), 64=>128) ["resnetv15_stage1_activation1"] ["resnetv15_stage2_batchnorm0_fwd"]
"resnetv15_stage2_batchnorm0_fwd" [128] 128 BatchNorm(128, λ = relu) ["resnetv15_stage2_conv0_fwd"] ["resnetv15_stage2_conv1_fwd"]
"resnetv15_stage2_conv1_fwd" [128] 128 Conv((3, 3), 128=>128) ["resnetv15_stage2_batchnorm0_fwd"] ["resnetv15_stage2_batchnorm1_fwd"]
"resnetv15_stage2_batchnorm1_fwd" [128] 128 BatchNorm(128) ["resnetv15_stage2_conv1_fwd"] ["resnetv15_stage2__plus0"]
"resnetv15_stage2__plus0" [128, 128] 128 #225 ["resnetv15_stage2_batchnorm2_fwd", "resnetv15_stage2_batchnorm1_fwd"] ["resnetv15_stage2_activation0"]
"resnetv15_stage2_activation0" [128] 128 #195 ["resnetv15_stage2__plus0"] ["resnetv15_stage2_conv3_fwd", "resnetv15_stage2__plus1"]
"resnetv15_stage2_conv3_fwd" [128] 128 Conv((3, 3), 128=>128) ["resnetv15_stage2_activation0"] ["resnetv15_stage2_batchnorm3_fwd"]
"resnetv15_stage2_batchnorm3_fwd" [128] 128 BatchNorm(128, λ = relu) ["resnetv15_stage2_conv3_fwd"] ["resnetv15_stage2_conv4_fwd"]
"resnetv15_stage2_conv4_fwd" [128] 128 Conv((3, 3), 128=>128) ["resnetv15_stage2_batchnorm3_fwd"] ["resnetv15_stage2_batchnorm4_fwd"]
"resnetv15_stage2_batchnorm4_fwd" [128] 128 BatchNorm(128) ["resnetv15_stage2_conv4_fwd"] ["resnetv15_stage2__plus1"]
"resnetv15_stage2__plus1" [128, 128] 128 #225 ["resnetv15_stage2_activation0", "resnetv15_stage2_batchnorm4_fwd"] ["resnetv15_stage2_activation1"]
"resnetv15_stage2_activation1" [128] 128 #195 ["resnetv15_stage2__plus1"] ["resnetv15_stage3_conv2_fwd", "resnetv15_stage3_conv0_fwd"]
"resnetv15_stage3_conv2_fwd" [128] 256 Conv((1, 1), 128=>256) ["resnetv15_stage2_activation1"] ["resnetv15_stage3_batchnorm2_fwd"]
"resnetv15_stage3_batchnorm2_fwd" [256] 256 BatchNorm(256) ["resnetv15_stage3_conv2_fwd"] ["resnetv15_stage3__plus0"]
"resnetv15_stage3_conv0_fwd" [128] 256 Conv((3, 3), 128=>256) ["resnetv15_stage2_activation1"] ["resnetv15_stage3_batchnorm0_fwd"]
"resnetv15_stage3_batchnorm0_fwd" [256] 256 BatchNorm(256, λ = relu) ["resnetv15_stage3_conv0_fwd"] ["resnetv15_stage3_conv1_fwd"]
"resnetv15_stage3_conv1_fwd" [256] 256 Conv((3, 3), 256=>256) ["resnetv15_stage3_batchnorm0_fwd"] ["resnetv15_stage3_batchnorm1_fwd"]
"resnetv15_stage3_batchnorm1_fwd" [256] 256 BatchNorm(256) ["resnetv15_stage3_conv1_fwd"] ["resnetv15_stage3__plus0"]
"resnetv15_stage3__plus0" [256, 256] 256 #225 ["resnetv15_stage3_batchnorm2_fwd", "resnetv15_stage3_batchnorm1_fwd"] ["resnetv15_stage3_activation0"]
"resnetv15_stage3_activation0" [256] 256 #195 ["resnetv15_stage3__plus0"] ["resnetv15_stage3_conv3_fwd", "resnetv15_stage3__plus1"]
"resnetv15_stage3_conv3_fwd" [256] 256 Conv((3, 3), 256=>256) ["resnetv15_stage3_activation0"] ["resnetv15_stage3_batchnorm3_fwd"]
"resnetv15_stage3_batchnorm3_fwd" [256] 256 BatchNorm(256, λ = relu) ["resnetv15_stage3_conv3_fwd"] ["resnetv15_stage3_conv4_fwd"]
"resnetv15_stage3_conv4_fwd" [256] 256 Conv((3, 3), 256=>256) ["resnetv15_stage3_batchnorm3_fwd"] ["resnetv15_stage3_batchnorm4_fwd"]
"resnetv15_stage3_batchnorm4_fwd" [256] 256 BatchNorm(256) ["resnetv15_stage3_conv4_fwd"] ["resnetv15_stage3__plus1"]
"resnetv15_stage3__plus1" [256, 256] 256 #225 ["resnetv15_stage3_activation0", "resnetv15_stage3_batchnorm4_fwd"] ["resnetv15_stage3_activation1"]
"resnetv15_stage3_activation1" [256] 256 #195 ["resnetv15_stage3__plus1"] ["resnetv15_stage4_conv2_fwd", "resnetv15_stage4_conv0_fwd"]
"resnetv15_stage4_conv2_fwd" [256] 512 Conv((1, 1), 256=>512) ["resnetv15_stage3_activation1"] ["resnetv15_stage4_batchnorm2_fwd"]
"resnetv15_stage4_batchnorm2_fwd" [512] 512 BatchNorm(512) ["resnetv15_stage4_conv2_fwd"] ["resnetv15_stage4__plus0"]
"resnetv15_stage4_conv0_fwd" [256] 512 Conv((3, 3), 256=>512) ["resnetv15_stage3_activation1"] ["resnetv15_stage4_batchnorm0_fwd"]
"resnetv15_stage4_batchnorm0_fwd" [512] 512 BatchNorm(512, λ = relu) ["resnetv15_stage4_conv0_fwd"] ["resnetv15_stage4_conv1_fwd"]
"resnetv15_stage4_conv1_fwd" [512] 512 Conv((3, 3), 512=>512) ["resnetv15_stage4_batchnorm0_fwd"] ["resnetv15_stage4_batchnorm1_fwd"]
"resnetv15_stage4_batchnorm1_fwd" [512] 512 BatchNorm(512) ["resnetv15_stage4_conv1_fwd"] ["resnetv15_stage4__plus0"]
"resnetv15_stage4__plus0" [512, 512] 512 #225 ["resnetv15_stage4_batchnorm2_fwd", "resnetv15_stage4_batchnorm1_fwd"] ["resnetv15_stage4_activation0"]
"resnetv15_stage4_activation0" [512] 512 #195 ["resnetv15_stage4__plus0"] ["resnetv15_stage4_conv3_fwd", "resnetv15_stage4__plus1"]
"resnetv15_stage4_conv3_fwd" [512] 512 Conv((3, 3), 512=>512) ["resnetv15_stage4_activation0"] ["resnetv15_stage4_batchnorm3_fwd"]
"resnetv15_stage4_batchnorm3_fwd" [512] 512 BatchNorm(512, λ = relu) ["resnetv15_stage4_conv3_fwd"] ["resnetv15_stage4_conv4_fwd"]
"resnetv15_stage4_conv4_fwd" [512] 512 Conv((3, 3), 512=>512) ["resnetv15_stage4_batchnorm3_fwd"] ["resnetv15_stage4_batchnorm4_fwd"]
"resnetv15_stage4_batchnorm4_fwd" [512] 512 BatchNorm(512) ["resnetv15_stage4_conv4_fwd"] ["resnetv15_stage4__plus1"]
"resnetv15_stage4__plus1" [512, 512] 512 #225 ["resnetv15_stage4_activation0", "resnetv15_stage4_batchnorm4_fwd"] ["resnetv15_stage4_activation1"]
"resnetv15_stage4_activation1" [512] 512 #195 ["resnetv15_stage4__plus1"] ["resnetv15_pool1_fwd"]
"resnetv15_pool1_fwd" [512] 512 #127 ["resnetv15_stage4_activation1"] ["flatten_170"]
"flatten_170" [512] 512 Flatten(-1) ["resnetv15_pool1_fwd"] ["resnetv15_dense0_fwd"]
"resnetv15_dense0_fwd" [512] 1000 Dense(512, 1000) ["flatten_170"] Any[]

All the methods above should have appropriate docstrings. I have thought about how to make the CompGraph print nicely, but I can't think of any good way; if you have any ideas, I'm all ears. For examining the structure I usually just use netron on an exported model. It is also possible to extract the CompGraph as a

Hope this helps!
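Building on the session above, one could also tally how often each operation type occurs instead of listing every vertex. A minimal sketch: `count_ops` is a hypothetical helper (not part of ONNXmutable or NaiveNASflux); with the real graph you would call it as `count_ops(layer.(vertices(cg)))`. The demo uses stand-in structs in place of Flux layer types.

```julia
# Hypothetical helper (not part of ONNXmutable/NaiveNASflux): tally how
# many vertices use each operation type, given the extracted layers.
function count_ops(layers)
    counts = Dict{String,Int}()
    for l in layers
        key = string(nameof(typeof(l)))
        counts[key] = get(counts, key, 0) + 1
    end
    return counts
end

# Stand-in structs playing the role of Flux layer types for the demo:
struct DemoConv end
struct DemoBatchNorm end

count_ops([DemoConv(), DemoBatchNorm(), DemoConv()])
```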
Send me a list of them and I'll add them when I find the time, or even better, try to implement them yourself using the documentation (filing issues about whatever does not make sense in it) and send me a PR :)
You could go for something like one of the two top answers here: https://stackoverflow.com/questions/42480111/model-summary-in-pytorch. It doesn't take care of networks with skip connections, but it does give a sense of the types of layers contained within. I think you actually have something quite close to this in your ad-hoc print?
In my case I'm only importing models, so I want to visualize the model in Julia to make sure ONNXmutable 'got things right'.
I think you've actually implemented everything I need, but I'll make an issue with any additional operations I need.
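The torchsummary-style listing mentioned above could be approximated with a few lines over the `name.(vs)`, `nin.(vs)` and `nout.(vs)` vectors from the earlier session. `print_summary` is a made-up name, not an ONNXmutable or NaiveNASflux function:

```julia
# Hypothetical torchsummary-style printer over parallel vectors of
# vertex names, input sizes and output sizes.
function print_summary(io::IO, names, nins, nouts)
    println(io, rpad("Layer", 30), rpad("nin", 12), "nout")
    for (n, i, o) in zip(names, nins, nouts)
        println(io, rpad(n, 30), rpad(string(i), 12), string(o))
    end
end

buf = IOBuffer()
print_summary(buf, ["data", "conv0_fwd"], [Any[], [3]], [3, 64])
print(String(take!(buf)))
```

As noted in the linked answer, a flat listing like this loses the skip-connection topology, which is why the input/output name columns in the table above are still useful.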
Yes, I was thinking of making something like that the
Ha, I realize now that it is only tangentially mentioned in the docs. I will advertise it much earlier, as it is a pretty important function. Thinking back, it was because I initially thought I would build and advertise the CompGraph as a LightGraph and
The definition of
Awesome!
Forgot to comment on this. Fwiw, there is a tool to compare the output of a CompGraph with the output from onnxruntime which is used in the unit tests. With a little bit of effort one can use it to make a final verification that it really was the same model that materialized (assuming onnxruntime got it right, ofc :)). Not the most fun thing when they don't produce the same output, but if they don't, then that's an issue here that becomes my problem :)
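The kind of final check described here could look something like the sketch below. `outputs_match` is a hypothetical helper, not the actual comparison tool in ONNXmutable's test code; the reference array would come from running the same input through onnxruntime.

```julia
# Hypothetical sanity check: compare the Julia model's output against a
# reference output (e.g. from onnxruntime) within a relative tolerance.
outputs_match(a, b; rtol=1e-4) = size(a) == size(b) && isapprox(a, b; rtol=rtol)

outputs_match([1.0, 2.0], [1.0, 2.0000001])  # matches within tolerance
outputs_match([1.0, 2.0], [1.0, 2.1])        # clearly different
```

Checking sizes first catches the most common import mix-ups (e.g. transposed weights) before any elementwise comparison.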
Better late than never, I guess. NaiveNASlib 2.1.0 uses PrettyTables to show a summary like the table above by default. The show just uses the more customizable function
I was hoping to use ONNXmutable.jl to extract the Flux operations specified in an .onnx file. Specifically, I'd like to be able to provide an .onnx file and load in a function corresponding to a composition of Flux primitives - something like [...]

What I've tried

- [...] CompGraph (or a ONNX.ModelProto() using extract)
- [...] CompGraph to understand what Flux operations are contained in the CompGraph

Use Case

- [...] JuMP variables (or regular floats), producing the appropriate output (appropriately constrained JuMP variables if JuMP variables are passed in, and regular floats corresponding to forward propagation otherwise)
- [...] .onnx files (with a subset of supported operations)
- [...] Flux layers [1] to replace my custom layers, and implement functionality that enables these Flux layers to accept JuMP variables and produce the appropriate output

[1] I'm using Flux because it seems to be the best supported framework related to neural networks in Julia, but would be open to suggestions to consider other options.

Additional Notes

The layers that I'd like implemented are a subset of the operations that are supported (in fact, I'd really like to start with just a simple feedforward network with Gemm and Relu layers).