
Gather in Upsample problem #192

Closed
aidonchuk opened this issue Jun 6, 2019 · 19 comments
@aidonchuk

aidonchuk commented Jun 6, 2019

Hi! I can't export my model from ONNX to TensorRT.

```
----------------------------------------------------------------
Input filename: model.onnx
ONNX IR version: 0.0.4
Opset version: 9
Producer name: pytorch
Producer version: 1.1
Domain:
Model version: 0
Doc string:

WARNING: ONNX model has a newer ir_version (0.0.4) than this parser was built against (0.0.3).
Parsing model
WARNING: Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. Attempting to cast down to INT32.
Successfully casted down to INT32.
While parsing node number 69 [Gather -> "208"]:
ERROR: /home/alex/tools/onnx-tensorrt/onnx2trt_utils.hpp:335 In function convert_axis:
[8] Assertion failed: axis >= 0 && axis < nbDims
```

The relevant part of the traced graph:

```
%206 : Long() = onnx::Constant[value={2}](), scope: ResNet18_OneConvDecoder/DecoderBlock[center]/Sequential[block]/Upsample[0]
%207 : Tensor = onnx::Shape(%205), scope: ResNet18_OneConvDecoder/DecoderBlock[center]/Sequential[block]/Upsample[0]
%208 : Long() = onnx::Gather[axis=0](%207, %206), scope: ResNet18_OneConvDecoder/DecoderBlock[center]/Sequential[block]/Upsample[0]
%209 : Tensor = onnx::Constant[value={2}]()
%210 : Tensor = onnx::Mul(%208, %209)
```
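For context, the failing check lives in onnx-tensorrt's `convert_axis` helper, which normalizes ONNX's (possibly negative) axis attribute to a non-negative TensorRT axis and asserts it is in range. A minimal Python sketch of that logic (the real code is C++; this function name and shape are only illustrative):

```python
def convert_axis(axis: int, nb_dims: int) -> int:
    """Mimic onnx-tensorrt's convert_axis: map a possibly-negative
    ONNX axis into [0, nb_dims), or fail the way the parser does."""
    if axis < 0:
        axis += nb_dims  # ONNX allows negative axes counted from the end
    # This is the check reported as "[8] Assertion failed: axis >= 0 && axis < nbDims"
    assert 0 <= axis < nb_dims, "axis >= 0 && axis < nbDims"
    return axis

# Gather[axis=0] over the rank-1 output of a Shape node is in range:
assert convert_axis(0, 1) == 0
# ONNX-style negative axes are normalized:
assert convert_axis(-1, 4) == 3
```

The assertion fires when the parser ends up checking the axis against the rank of a tensor it didn't expect, which is what happens with the `Shape`/`Gather` chain that PyTorch's Upsample trace produces.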

@aidonchuk
Author

Please help!!!

@yalagamsrinivas

I have the same issue; please suggest a solution.

@zimenglan-sysu-512

Any ideas on how to solve it?

@aidonchuk
Author

aidonchuk commented Aug 6, 2019

I replaced Upsample with ConvTranspose2d, but I don't like that workaround.

@zimenglan-sysu-512

zimenglan-sysu-512 commented Aug 8, 2019

I solved this problem by exporting only the backbone (without the FPN) to ONNX.

@tibistrat

I also had this issue with a model coming from PyTorch. Here's an explanation of what I did to work around the problem:

This PyTorch model, when exported to ONNX, fails when importing in TensorRT because of the Gather operation:

```python
import torch.nn as nn

class ShapeModel(nn.Module):
    def __init__(self):
        super(ShapeModel, self).__init__()
    def forward(self, x):
        return x.shape
```

[traced graph image: ShapeDummyModel]
Assertion failed: axis >= 0 && axis < nbDims

However, this model works:

```python
import torch
import torch.nn as nn

class ShapeModel(nn.Module):
    def __init__(self):
        super(ShapeModel, self).__init__()
    def forward(self, x):
        return torch.tensor(x.shape)
```

with just a warning during export to ONNX that the trace might not generalize to other inputs.
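As far as I can tell, the difference is that `x.shape` is re-derived dynamically in the traced graph (a `Shape` node, plus a `Gather` wherever an element is indexed), while `torch.tensor(x.shape)` is recorded as a single constant at trace time; that is also why the exporter warns the trace may not generalize. A quick check of the two forms (assumes PyTorch is installed; purely illustrative):

```python
import torch

x = torch.randn(1, 3, 4, 4)

# x.shape is a torch.Size, i.e. a tuple subclass of Python ints;
# when traced for ONNX, uses of it become Shape (and Gather) nodes.
assert isinstance(x.shape, tuple)
assert x.shape == (1, 3, 4, 4)

# torch.tensor(x.shape) materializes the shape as a constant tensor,
# so the exported graph carries a baked-in Constant instead.
s = torch.tensor(x.shape)
assert torch.equal(s, torch.tensor([1, 3, 4, 4]))
```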

A single PyTorch upsampling by a factor of 2 gets traced like this:

```python
import torch.nn as nn
import torch.nn.functional as F

class ResizeModel(nn.Module):
    def __init__(self):
        super(ResizeModel, self).__init__()
    def forward(self, x):
        return F.interpolate(x, scale_factor=(2, 2), mode='nearest')
```

[traced graph image: ResizeDummyModel]
in which a lot of work is done to determine the desired size of the output tensor (and in which Gather appears).

PyTorch's interpolate function will also work if you supply not the upsampling factor but the already-known output size of your tensor. Below is an upsampler for (batch_size x channels x H x W) tensors:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResizeModel(nn.Module):
    def __init__(self):
        super(ResizeModel, self).__init__()
    def forward(self, x):
        sh = torch.tensor(x.shape)
        return F.interpolate(x, size=(sh[2] * 2, sh[3] * 2), mode='nearest')
```

Which gets traced to ONNX like this:
[traced graph image: ResizeDummyModel_workaround]
thus avoiding the Gather, and it works in TensorRT.
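For reference, here is a self-contained sketch of this workaround, combined with explicit `int()` casts (a variant also suggested later in the thread) so it behaves the same in eager mode; the module name and the fixed 2x factor are just illustrative (assumes PyTorch is installed):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResizeModel(nn.Module):
    """Upsample an (N, C, H, W) tensor by 2x using an explicit target size,
    so the exported ONNX graph carries trace-time constants instead of the
    Shape/Gather chain that onnx-tensorrt rejects."""
    def forward(self, x):
        # int(...) forces plain Python ints (constants in the trace),
        # not traced shape tensors.
        h, w = int(x.shape[2]) * 2, int(x.shape[3]) * 2
        return F.interpolate(x, size=(h, w), mode='nearest')

model = ResizeModel().eval()
x = torch.randn(1, 3, 4, 4)
y = model(x)
assert y.shape == (1, 3, 8, 8)

# Export would then be along the lines of:
# torch.onnx.export(model, x, "resize.onnx", opset_version=9)
```

Note that baking the sizes in this way means the exported graph is specialized to the traced input shape.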

@aidonchuk
Author

@tibistrat you are an absolute magician! Thanks a lot!

@lucasjinreal
Contributor

@tibistrat Thanks for your explanation. But actually, with this code:

```python
return F.interpolate(x, size=(sh[2] * 2, sh[3] * 2), mode='nearest')
```

it still fails in onnx2trt; it throws this error:

```
[8] Assertion failed: axis >= 0 && axis < nbDims
```

I have tested with a similar code:

```python
up3 = F.interpolate(output3, size=(output2.size(2), output2.size(3)), mode="nearest")
output2 = output2 + up3
```

Once I comment out these lines, it works; uncommented, it raises the error.

I still don't know why it happens or how to solve it.

Any suggestions?

@backtime92

@jinfagang Have you solved it?

@lucasjinreal
Contributor

No, I don't know the exact reason.

@kealennieh

@jinfagang try

```python
up3 = F.interpolate(output3, size=(int(output2.size(2)), int(output2.size(3))), mode="nearest")
```

@freetown113

@jinfagang You know your exact output size, right? You have to enter those exact values, for example:

```python
F.interpolate(output3, size=[64, 64], mode="nearest")
```

You also need to provide an explicit output shape in the `forward` function, for example:

```python
x = x.view(1, -1, 1024, 1024)
```

as mentioned here: pytorch/pytorch#16908

@DJMeng

DJMeng commented Nov 18, 2019

@jinfagang Have you solved it? Thanks.

@DJMeng

DJMeng commented Nov 18, 2019

@aidonchuk have you solved it the way @tibistrat said? Why do I still get errors? Please help me, thanks.

@banhbaomocmeo

@jinfagang try

```python
up3 = F.interpolate(output3, size=(int(output2.size(2)), int(output2.size(3))), mode="nearest")
```

It works perfectly, thank you

@LeviViana

@tibistrat Thanks a lot !

@codeslord

You can use onnx-simplifier to achieve the same:
https://github.com/daquexian/onnx-simplifier

@daixiangzi

> (quoting @tibistrat's workaround from above in full)

It looks cool!

@kevinch-nv
Collaborator

Closing since this has been resolved.
