(Upsample) How can I use onnx parser with opset 11? #284

Closed
dhkim0225 opened this issue Dec 19, 2019 · 34 comments

@dhkim0225 commented Dec 19, 2019

Description

The ONNX parser is built by default with ir_version 3, opset 7 (https://github.com/onnx/onnx-tensorrt/blob/master/onnx_trt_backend.cpp).

Is there any way to use the ONNX parser with opset 11 support?

The parser works only with opset 7: it parses an ir_version 4 / opset 7 ONNX model fine, but fails on an ir_version 4 / opset 11 model.

It also cannot parse opset 8 and 9 models.

My ONNX models are exported from PyTorch 1.4.0a.

Can I rebuild the parser by changing only the BACKEND_OPSET constant inside onnx_trt_backend.cpp?

Environment

TensorRT Version: 7.0.0
GPU Type: T4
Nvidia Driver Version: 440.33.01
CUDA Version: 10.2.89
CUDNN Version: 7.6.5
Operating System + Version: Ubuntu18.04
Python Version (if applicable): 3.6.9
TensorFlow Version (if applicable): 1.4.0
PyTorch Version (if applicable): 1.4.0a

@rmccorm4

Hi @dhkim0225,

How are you parsing the ONNX model? If you're using trtexec, please share the command; if you're using the API, please share the code.

@dhkim0225 commented Dec 20, 2019

@rmccorm4 Thank you for the reply. (Really, thanks a lot.)

I tried both of them. Here's my code.

PyTorch to ONNX

import torch
from torch import onnx  # torch.onnx, not the standalone onnx package

def main(cfg):
    net = get_model(cfg['model'], 0, weight_file=None, verbose=cfg['eval']['verbose'])
    net.eval()

    with torch.no_grad():
        dummy_input = torch.randn(1, 3, 1920, 1920, device='cuda')
        torch_out = net(dummy_input)
        onnx.export(net, dummy_input, "./my_trt/model.onnx",
                    export_params=True,
                    verbose=False,
                    training=False,
                    input_names=None,
                    output_names=None,
                    operator_export_type=onnx.OperatorExportTypes.ONNX,
                    opset_version=11,
                    do_constant_folding=True,
                    example_outputs=torch_out,
                    strip_doc_string=True,
                    dynamic_axes=None,
                    keep_initializers_as_inputs=True)

trtexec

trtexec --onnx=model.onnx --explicitBatch

The following errors were printed:

...
----------------------------------------------------------------
Input filename:   model.onnx
ONNX IR version:  0.0.4
Opset version:    11
Producer name:    pytorch
Producer version: 1.3
Domain:           
Model version:    0
Doc string:       
----------------------------------------------------------------
...
[12/20/2019-11:33:54] [W] [TRT] onnx2trt_utils.cpp:198: Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. Attempting to cast down to INT32.
While parsing node number 185 [Resize]:
ERROR: ModelImporter.cpp:124 In function parseGraph:
[5] Assertion failed: ctx->tensors().count(inputName)
[12/20/2019-11:33:54] [E] Failed to parse onnx file
[12/20/2019-11:33:54] [E] Parsing model failed
[12/20/2019-11:33:54] [E] Engine creation failed
[12/20/2019-11:33:54] [E] Engine set up failed
&&&& FAILED TensorRT.trtexec # trtexec --onnx=model.onnx --explicitBatch

Convert with the API

import tensorrt as trt

TRT_LOGGER = trt.Logger()  # assumed: defined elsewhere in the original script

def get_engine(mode):
    EXPLICIT_BATCH = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    with trt.Builder(TRT_LOGGER) as builder, \
            builder.create_network(EXPLICIT_BATCH) as network, \
            trt.OnnxParser(network, TRT_LOGGER) as parser:
        builder.max_batch_size = 1
        builder.fp16_mode = True if mode == 'fp16' else False
        builder.int8_mode = True if mode == 'int8' else False
        builder.max_workspace_size = 1 << 32  # 4 GiB (1 << 30 would be 1 GiB)

        with open(onnx_file_path, 'rb') as model:
            parser.parse(model.read())

        print(len(network))  # Printed output == 0. Something is wrong. 

        engine = builder.build_cuda_engine(network)
        with open(engine_file_path, "wb") as f:
            f.write(engine.serialize())
    return engine

The output messages are as follows:

...
[TensorRT] WARNING: onnx2trt_utils.cpp:198: Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. Attempting to cast down to INT32.
Completed parsing of ONNX file
Building an engine from file my_trt/model.onnx; this may take a while...
[TensorRT] ERROR: Network must have at least one output
[TensorRT] ERROR: Network validation failed.

....

trtexec reports an assertion error at ModelImporter.cpp:124.
That's why I think the ONNX parser only supports ir_version 3.

I found a similar issue here.

When I use onnxruntime, the model runs without any error.

Best Regards,

@rmccorm4

Hi @dhkim0225,

The appreciation is appreciated 🙂

Re: the trtexec errors, I'll look into it, hopefully tomorrow.

Re: the Python API, I noticed the comment where the network has 0 layers; that's because parsing failed. For future reference, you can get better diagnostics by checking the output of something like parser.get_error(0) when parsing fails. See this comment: #283 (comment)

@dhkim0225 commented Dec 20, 2019

@rmccorm4

The appreciation for the appreciation is appreciated :p

I made a clean new Docker container and built TRT from scratch. Now len(network) returns the right value.

Maybe this is a PyTorch issue:
pytorch/pytorch#30393

With the onnxruntime package this model works well, but when I run the following test code, a segmentation fault occurs. It's really strange.

import onnx

onnx_model = onnx.load(args.onnx_model_path)
onnx.checker.check_model(onnx_model)

I made my ONNX model with the NGC container (nvcr.io/nvidia/pytorch:19.12-py3), which contains PyTorch 1.4.0a.

This is part of my model. You can see the 'Constant' node, and the structure is quite similar to the PyTorch issue above.
[Netron screenshot]

Please let me know if you find anything on this issue.

Best Regards,

=============================
P.S. versions.
onnx==1.6.0
onnxruntime==1.1.0
netron==3.6.5

@rmccorm4

Hey @dhkim0225,

As a potential workaround in the meantime while I look into it: just curious, what happens when you use PyTorch 1.3 or 1.2 to export the model?

1.4 is pretty bleeding edge so I'm not sure if it introduced anything that might be causing issues.

@dhkim0225 commented Dec 20, 2019

Well, I didn't test with PyTorch 1.2 since it only supports opset versions up to 10.

With PyTorch 1.3.0 and 1.3.1, the ONNX model not only fails the checker but also doesn't work with onnxruntime. Netron's visualization of the model is the same as for the model from torch 1.4.0.

I'll check len(network) after parser.parse() if you want! :)

@rmccorm4 commented Dec 20, 2019

Sure, I'm curious what's different with 1.3 / 1.3.1:

if not parser.parse(f.read()):
    print('ERROR: Failed to parse the ONNX file.')
    for error in range(parser.num_errors):
        print(parser.get_error(error))

@dhkim0225 commented Dec 23, 2019

@rmccorm4 Sorry for the late reply.

I made 12 ONNX models with PyTorch 1.3.0 and 1.3.1 (six ONNX models per PyTorch version).

You can see my toy code here.

All 12 ONNX models produce the following error:
[TensorRT] ERROR: Network must have at least one output
[TensorRT] ERROR: Network validation failed.

All the ONNX models are reported as invalid_graph (checked with onnx.checker.check_graph).

The results of parser.get_error(error) and len(network) are as follows:

...
5
In node 3 (parseGraph): INVALID_GRAPH: Assertion failed: ctx->tensors().count(inputName)
1
LayerType.CONVOLUTION

network 0: nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
network 1: nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
network 2: nn.Upsample(scale_factor=2, mode='nearest')
network 3: nn.Upsample((256, 256), mode='bilinear', align_corners=False)
network 4: nn.Upsample((256, 256), mode='bilinear', align_corners=True)
network 5: nn.Upsample((256, 256), mode='nearest')

(Netron screenshots omitted)

@rmccorm4

Hi @dhkim0225,

For the onnx.check_graph(model) error, that's just a syntax problem. You should actually be calling either (I think the first one is preferred):

onnx.checker.check_model(model)

or

onnx.checker.check_graph(model.graph)

@rmccorm4

Regarding the real issue reported by TensorRT when trying to parse the model, I'm guessing it's coming from the Upsample op. I've seen a few other users run into similar difficulties, which I had hoped were fixed in TRT 7, but apparently not.

Although I'm not sure whether this is a TRT issue or a PyTorch/ONNX issue. I was hoping that using onnx-simplifier might help, but it errors out on these models:

root@f75a4be406e3:/mnt/reproduce-trt-issue-284# python3 -m onnxsim tmp/0.onnx tmp/0.simple.onnx
Simplifying...
2019-12-23 22:37:03.603357564 [E:onnxruntime:, sequential_executor.cc:183 Execute] Non-zero status code returned while running Resize node. Name:'' Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/tensor/upsample.h:281 void onnxruntime::UpsampleBase::ScalesValidation(const std::vector<float>&, onnxruntime::UpsampleMode) const scales.size() == 2 || (scales.size() == 4 && scales[0] == 1 && scales[1] == 1) was false. 'Linear' mode and 'Cubic' mode only support 2-D inputs ('Bilinear', 'Bicubic') or 4-D inputs with the corresponding outermost 2 scale values being 1 in the Resize operator
Stacktrace:

Traceback (most recent call last):
  File "/usr/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/usr/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/usr/local/lib/python3.6/dist-packages/onnxsim/__main__.py", line 38, in <module>
    main()
  File "/usr/local/lib/python3.6/dist-packages/onnxsim/__main__.py", line 31, in main
    args.input_model, check_n=args.check_n, perform_optimization=not args.skip_optimization, input_shapes=input_shapes)
  File "/usr/local/lib/python3.6/dist-packages/onnxsim/onnx_simplifier.py", line 261, in simplify
    res = forward_all(model_opt, input_shapes=input_shapes)
  File "/usr/local/lib/python3.6/dist-packages/onnxsim/onnx_simplifier.py", line 140, in forward_all
    res = forward(model, input_shapes=input_shapes)
  File "/usr/local/lib/python3.6/dist-packages/onnxsim/onnx_simplifier.py", line 132, in forward
    res = OrderedDict(zip(outputs, sess.run(outputs, inputs)))
  File "/usr/local/lib/python3.6/dist-packages/onnxruntime/capi/session.py", line 142, in run
    return self._sess.run(output_names, input_feed, run_options)
onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Resize node. Name:'' Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/tensor/upsample.h:281 void onnxruntime::UpsampleBase::ScalesValidation(const std::vector<float>&, onnxruntime::UpsampleMode) const scales.size() == 2 || (scales.size() == 4 && scales[0] == 1 && scales[1] == 1) was false. 'Linear' mode and 'Cubic' mode only support 2-D inputs ('Bilinear', 'Bicubic') or 4-D inputs with the corresponding outermost 2 scale values being 1 in the Resize operator
Stacktrace:

@rmccorm4

Many people are on holiday this week and next, but hopefully we'll be able to find something useful in the next couple of weeks.

@dhkim0225 commented Dec 23, 2019

Thank you for taking the trouble to help me.
I really do appreciate it.

I'll be waiting for new comments. Please tell me if there's anything I can do to help.

Sincerely yours,

@rmccorm4 changed the title from "How can I use onnx parser with opset 11?" to "(Upsample) How can I use onnx parser with opset 11?" on Dec 29, 2019
@ksnzh commented Jan 3, 2020

I ran into the same error when using an upsample layer with PyTorch 1.3.1.
Using UpsamplingNearest2d with PyTorch 1.2.0 works for me.

@rmccorm4 commented Jan 3, 2020

Per ONNX, this seems to be a limitation in the supported parameters of the Upsample (or, indirectly, Resize) op:

[ONNXRuntimeError]
'Linear' mode and 'Cubic' mode only support 2-D inputs ('Bilinear', 'Bicubic') 
or 4-D inputs with the corresponding outermost 2 scale values being 1 in the Resize operator
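
In concrete terms, for a 4-D NCHW input the scales input must look like [1, 1, scale_h, scale_w]. Below is a minimal sketch of a Resize that satisfies this validation (hand-built with onnx.helper rather than exported from PyTorch, so all names are illustrative):

import onnx
from onnx import TensorProto, helper

# Linear-mode Resize on a 4-D input: the outermost two scale values must be 1.
node = helper.make_node(
    'Resize', inputs=['X', 'roi', 'scales'], outputs=['Y'],
    mode='linear', coordinate_transformation_mode='asymmetric')

graph = helper.make_graph(
    [node], 'resize_ok',
    inputs=[helper.make_tensor_value_info('X', TensorProto.FLOAT, [1, 3, 256, 256])],
    outputs=[helper.make_tensor_value_info('Y', TensorProto.FLOAT, [1, 3, 512, 512])],
    initializer=[
        # roi is unused for 'asymmetric' but must still be provided as an input
        helper.make_tensor('roi', TensorProto.FLOAT, [0], []),
        helper.make_tensor('scales', TensorProto.FLOAT, [4], [1.0, 1.0, 2.0, 2.0]),
    ])

model = helper.make_model(graph, opset_imports=[helper.make_opsetid('', 11)])
onnx.checker.check_model(model)  # passes: scales [1, 1, 2, 2] satisfy the constraint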

@qizhen816

Got the same problem here.
It seems the Upsample (torch.nn.functional.interpolate) op emits constant tensors after PyTorch 1.2.0.
Some constant values don't have any input, which makes them leaf input nodes.
I tried interpolate in PyTorch 1.3.0 and 1.4.0 with new shapes or scales, and it always gave me something like this:

%2 : Tensor = onnx::Constant[value=[ CPUFloatType{0} ]]()
%3 : Tensor = onnx::Constant[value= 1  3 [ CPULongType{2} ]]()
%4 : Tensor = onnx::Cast[to=7](%1)
%5 : Tensor = onnx::Concat[axis=0](%3, %4)

The opset_version here is 11, and I got a TensorRT error (PyTorch 1.4.0):
In node 5 (parseGraph): INVALID_GRAPH: Assertion failed: ctx->tensors().count(inputName)
Here is my test code:

import torch
import torch.nn as nn
import torch.nn.functional as F

class TestModel(nn.Module):
    def __init__(self):
        super(TestModel, self).__init__()

    def forward(self, x):
        x = F.interpolate(x, (256, 256), mode='bilinear')
        return x

torch_model = TestModel()
dummy_input = torch.randn((1, 3, 256, 256))
torch_out = torch.onnx.export(torch_model,
                              dummy_input,
                              'test_model.onnx',
                              verbose=True,
                              opset_version=11)

And Netron shows:
[Netron screenshot]

@lara-hdr Hi, I saw you've worked with Torch and ONNX in other issues; could you please help analyze this problem? It has been bothering me for days :(
Thanks to @ksnzh, I'm now using PyTorch 1.2.0 with UpsamplingNearest2d; the conversion works, but with some deviation.

@qizhen816

@ksnzh @dhkim0225 @rmccorm4 What's your onnx version? Mine was 1.6.0; then I installed 1.4.0 via pip install onnx==1.4.0 with PyTorch 1.3.1, and the constant magically disappeared!

@ksnzh commented Jan 6, 2020 via email

@qizhen816

Hey, same here. I found that the root cause of this problem is onnx 1.6.0; after reverting to 1.4.0, linear upsampling works fine.

> I use onnx version 1.6.0 with pytorch 1.2.0.

@lara-hdr commented Jan 6, 2020

@qizhen816, I tested your code with PyTorch master (nightly) and onnx 1.6.0, and the issue seems fixed; could you confirm?

@qizhen816

@lara-hdr It's not working... I just found out that the reason my code worked previously is that I used opset_version=10, so the issue may not be caused by onnx.
This is what I got using https://download.pytorch.org/whl/nightly/cu101/torch-1.4.0.dev20200106-cp35-cp35m-linux_x86_64.whl and onnx 1.6.0:

graph(%x : Float(1, 3, 256, 256),
      %12 : Long(2)):
  %2 : Tensor = onnx::Constant[value=[ CPUFloatType{0} ]]()
  %3 : Tensor = onnx::Shape(%x)
  %4 : Tensor = onnx::Constant[value={0}]()
  %5 : Tensor = onnx::Constant[value={0}]()
  %6 : Tensor = onnx::Constant[value={2}]()
  %7 : Tensor = onnx::Slice(%3, %5, %6, %4)
  %9 : Tensor = onnx::Concat[axis=0](%7, %12)
  %10 : Tensor = onnx::Constant[value=[ CPUFloatType{0} ]]()
  %11 : Float(1, 3, 256, 256) = onnx::Resize[coordinate_transformation_mode="pytorch_half_pixel", cubic_coeff_a=-0.75, mode="linear", nearest_mode="floor"](%x, %2, %10, %9) # /home/qizhen/.local/lib/python3.5/site-packages/torch/nn/functional.py:2575:0
  return (%11)

And: In node 8 (parseGraph): INVALID_GRAPH: Assertion failed: ctx->tensors().count(inputName)
However, when I switch to opset_version=10, it gives:

graph(%x : Float(1, 3, 256, 256)):
  %1 : Tensor = onnx::Constant[value= 1  1  1  1 [ CPUFloatType{4} ]]()
  %2 : Float(1, 3, 256, 256) = onnx::Resize[mode="linear"](%x, %1) # /home/qizhen/.local/lib/python3.5/site-packages/torch/nn/functional.py:2575:0
  return (%2)

And TensorRT shows no error.
I can't use opset 10 in my project because it doesn't support align_corners and raises a UserWarning. So why does this happen? It doesn't seem to be a common problem, but it's hard to fix.

@qizhen816

@lara-hdr Sorry, I've been testing the code as well as my project code with Torch and ONNX these days, and they are all good. The Upsample issue remains with ONNX to TensorRT. Thanks for the help :)
By the way, there is a small issue with onnx: if I import torch before onnx, the onnx.checker.check_model() function causes a Segmentation fault (core dumped); otherwise it's fine.
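
For anyone hitting the same crash, here is a minimal sketch of the import-order workaround (assuming the test_model.onnx exported by the snippet earlier in this thread):

import onnx   # importing onnx before torch avoids the reported segfault
import torch  # noqa: F401  (imported second on purpose)

model = onnx.load('test_model.onnx')
onnx.checker.check_model(model)
print('check_model passed')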

@rmccorm4 commented Jan 10, 2020

Hi @dhkim0225,

So after looking into this, the original problem was the one below:

While parsing node number 185 [Resize]:
ERROR: ModelImporter.cpp:124 In function parseGraph:
[5] Assertion failed: ctx->tensors().count(inputName)

Thanks to @kevinch-nv, it looks like the root cause of the issue was how PyTorch exports opset 11 resizes. Looking at the export code (https://github.com/pytorch/pytorch/blob/master/torch/onnx/symbolic_opset11.py#L177), PyTorch inserts empty "constant" layers for optional inputs, and the ONNX parser did not accept this case.
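
To see what that looks like in an exported model, here is a small inspection sketch (the model path is hypothetical); it lists each Resize node's inputs and flags the ones fed by zero-element Constant nodes:

import onnx

model = onnx.load('model.onnx')  # hypothetical path: the opset-11 export above

# Collect the outputs of Constant nodes whose value tensor has zero elements.
empty_consts = set()
for node in model.graph.node:
    if node.op_type == 'Constant':
        for attr in node.attribute:
            if attr.name == 'value' and 0 in attr.t.dims:
                empty_consts.update(node.output)

for node in model.graph.node:
    if node.op_type == 'Resize':
        flagged = [name for name in node.input if name in empty_consts]
        print('Resize inputs:', list(node.input), '-> fed by empty constants:', flagged)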

It should now be fixed by this PR: onnx/onnx-tensorrt#369


To apply those changes, you can build the OSS components (https://github.com/rmccorm4/tensorrt-utils/blob/20.01/OSS/build_OSS.sh) on top of your TRT install / container like so:

wget https://raw.githubusercontent.com/rmccorm4/tensorrt-utils/20.01/OSS/build_OSS.sh
source build_OSS.sh

But there is still another issue affecting the following models:

  • 0.onnx - nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
  • 1.onnx - nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
  • 3.onnx - nn.Upsample((256, 256), mode='bilinear', align_corners=False)
  • 4.onnx - nn.Upsample((256, 256), mode='bilinear', align_corners=True)

This is because TensorRT only supports asymmetric resizing at the moment:

While parsing node number 24 [Resize]:
ERROR: /workspace/TensorRT/parsers/onnx/builtin_op_importers.cpp:2435 In function importResize:
[8] Assertion failed: (transformationMode == "asymmetric") && "This version of TensorRT only supports asymmetric resize!"
&&&& FAILED TensorRT.trtexec # trtexec --onnx=0.onnx --explicitBatch

Your model 5.onnx - nn.Upsample((256, 256), mode='nearest') now parses successfully for me:

----------------------------------------------------------------
Input filename:   5.onnx
ONNX IR version:  0.0.4
Opset version:    11
Producer name:    pytorch
Producer version: 1.3
Domain:           
Model version:    0
Doc string:       
----------------------------------------------------------------
...
&&&& PASSED TensorRT.trtexec # trtexec --onnx=5.onnx --explicitBatch

Lastly, 2.onnx - nn.Upsample(scale_factor=2, mode='nearest') is hitting a different error:

----------------------------------------------------------------
Input filename:   2.onnx
ONNX IR version:  0.0.4
Opset version:    11
Producer name:    pytorch
Producer version: 1.3
Domain:           
Model version:    0
Doc string:       
----------------------------------------------------------------
[01/10/2020-01:42:34] [W] [TRT] /workspace/TensorRT/parsers/onnx/onnx2trt_utils.cpp:232: Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. Attempting to cast down to INT32.
[01/10/2020-01:42:34] [W] [TRT] /workspace/TensorRT/parsers/onnx/onnx2trt_utils.cpp:232: Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. Attempting to cast down to INT32.
[01/10/2020-01:42:34] [W] [TRT] /workspace/TensorRT/parsers/onnx/onnx2trt_utils.cpp:232: Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. Attempting to cast down to INT32.
[01/10/2020-01:42:34] [W] [TRT] Tensor DataType is determined at build time for tensors not marked as input or output.
[01/10/2020-01:42:34] [W] [TRT] Tensor DataType is determined at build time for tensors not marked as input or output.
[01/10/2020-01:42:34] [W] [TRT] Tensor DataType is determined at build time for tensors not marked as input or output.
[01/10/2020-01:42:34] [W] [TRT] Tensor DataType is determined at build time for tensors not marked as input or output.
[01/10/2020-01:42:34] [W] [TRT] Tensor DataType is determined at build time for tensors not marked as input or output.
[01/10/2020-01:42:34] [W] [TRT] Tensor DataType is determined at build time for tensors not marked as input or output.
[01/10/2020-01:42:34] [W] [TRT] Tensor DataType is determined at build time for tensors not marked as input or output.
[01/10/2020-01:42:34] [E] [TRT] Layer: (Unnamed Layer* 9) [Unary]'s output can not be used as shape tensor.
[01/10/2020-01:42:34] [E] [TRT] Network validation failed.
[01/10/2020-01:42:34] [E] Engine creation failed
[01/10/2020-01:42:34] [E] Engine set up failed
&&&& FAILED TensorRT.trtexec # trtexec --onnx=2.onnx --explicitBatch

It seems like PyTorch generates a pretty complex graph for 2.onnx even though it should be similar to 5.onnx.

I tried using onnx-simplifier on 2.onnx and it produced a much simpler graph which TensorRT was able to parse correctly:

root@420da1448d33:/mnt/reproduce-trt-issue-284/tmp# python -m pip install onnx-simplifier
root@420da1448d33:/mnt/reproduce-trt-issue-284/tmp# python -m onnxsim 2.onnx 2.simple.onnx
Simplifying...
Ok!
root@420da1448d33:/mnt/reproduce-trt-issue-284/tmp# ls -lh
-rw-r--r-- 1 1003 1003 4.2K Jan 10 02:28 2.onnx
-rw-r--r-- 1 root root 3.7K Jan 10 02:36 2.simple.onnx

root@420da1448d33:/mnt/reproduce-trt-issue-284/tmp# trtexec --onnx=2.simple.onnx --explicitBatch
...
&&&& PASSED TensorRT.trtexec # trtexec --onnx=2.simple.onnx --explicitBatch

@daquexian

> Regarding the real issue reported by TensorRT when trying to parse the model, I'm guessing it's coming from the Upsample op. [...] I was hoping that using onnx-simplifier might help, but it errors out on these models: [quoting @rmccorm4's comment and the onnx-simplifier stack trace above]

I'm the author of both the ONNX Resize op in opset 11 and onnx-simplifier. Could you please send your model to my email daquexian566@gmail.com so that I can try to help? Thanks!

@kaishian

@daquexian Thank you for your attention to this problem.

Here is my environment:
pytorch : 1.4.0a0 + 46f32e1
onnx : 1.6.0
onnx2trt: 7.0
tensorrt: 7.0.0.11

Here are the details when exporting the model to onnx:
log file

This is the onnx model file (opset 11):
model file

Hope this information is helpful for you.

@rmccorm4 commented Jan 25, 2020

Per the original issue, "(Upsample) How can I use onnx parser with opset 11?", this has been resolved in the upstream ONNX parser, per the posts above.

However, another open issue from this thread is:

[8] Assertion failed: (transformationMode == "asymmetric") && "This version of TensorRT only supports asymmetric resize!"

If you need this, please open a separate RFE and comment there with your use cases for supporting it; the more info, the better.
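
In the meantime, one unofficial workaround (a sketch only; the file names are placeholders, and forcing the mode changes the interpolation semantics slightly, so verify your accuracy afterwards) is to rewrite the coordinate_transformation_mode attribute on the exported Resize nodes:

import onnx

model = onnx.load('model.onnx')  # placeholder path
for node in model.graph.node:
    if node.op_type == 'Resize':
        for attr in node.attribute:
            if attr.name == 'coordinate_transformation_mode':
                attr.s = b'asymmetric'  # the only mode this TRT version accepts
onnx.save(model, 'model_asymmetric.onnx')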

@cloudrivers

@rmccorm4 I'm hitting this error:
[8] Assertion failed: get_shape_size(new_shape) == get_shape_size(tensor.getDimensions())

@wep21 commented Feb 4, 2020

I had the same error as the one in @rmccorm4's post above.
[2020-02-04 11:50:49 ERROR] Layer: (Unnamed Layer* 279) [Unary]'s output can not be used as shape tensor.
[2020-02-04 11:50:49 ERROR] Network validation failed.
terminate called after throwing an instance of 'std::runtime_error'
what(): Failed to create object

Here is my environment:
pytorch : 1.4.0
onnx : 1.6.0
onnx2trt: master(source build)
tensorrt: 7.0.0.11
I executed ./onnx2trt yolov3.onnx -o my_engine.trt.
The model contains nn.Upsample(scale_factor=2, mode='nearest').
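
(Since this looks like the same failure mode as 2.onnx above, the onnx-simplifier workaround from the earlier comment may apply here as well; a sketch, assuming the same root cause:)

python3 -m pip install onnx-simplifier
python3 -m onnxsim yolov3.onnx yolov3.simple.onnx
./onnx2trt yolov3.simple.onnx -o my_engine.trt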

@ycchanau commented Mar 8, 2020

> Got the same problem here. It seems the Upsample (torch.nn.functional.interpolate) op emits constant tensors after PyTorch 1.2.0. [...] [quoting @qizhen816's comment above]
@qizhen816 Hi, I have the same problem as you. Have you solved it yet? Would you mind telling me how you solved it?

@qizhen816

@ycchanau Sorry for the delay. My ONNX problem ended with torch 1.4.0. At first the error occurred in ONNX, so I tried to make the node look normal; but after testing with onnxruntime it seemed OK. In the end, it turned out that TensorRT didn't support the ONNX bilinear upsample op, so I gave up and used other acceleration libraries.

@ycchanau

@qizhen816 Thanks so much for your reply. May I know which libraries you are using now? I have been stuck on TensorRT for a week.

@qizhen816

@ycchanau lol, I'm from China, so I tried two inference engines from Alibaba and Tencent: MNN (https://github.com/alibaba/MNN) and NCNN (https://github.com/Tencent/ncnn). They are both developing rapidly, with full support for Windows, Linux, and mobile devices. In my opinion the first one is better; the most important reason is that bilinear upsample works perfectly, haha.

@kedardg commented Apr 10, 2020

> ...it turned out that TensorRT didn't support the ONNX bilinear upsample op, so I gave up and used other acceleration libraries.

You can try onnx-simplifier. The simplified ONNX file should work for you.

@kingyj7 commented Jun 12, 2020

If I understand correctly, even the latest TensorRT (7.0 and 7.1) does not support bilinear upsample?

@jovialio commented Jan 8, 2021

I raised this issue under onnx-tensorrt. It might help if you are having issues with interpolation.

onnx/onnx-tensorrt#615
