Getting different output sizes when using exported torchscript #1562

Closed
dagap opened this issue Nov 30, 2020 · 6 comments
Labels
bug Something isn't working Stale

Comments

dagap commented Nov 30, 2020

I am running some dummy data through the model as follows:

import torch

image = torch.zeros(1, 3, 640, 640).to('cuda:0')
model = torch.load('yolov5.pt', map_location='cuda:0')['model'].float().fuse().eval()
pred = model(image)[0]

The shape of the output prediction is [1, 192, 80, 80]

I exported the model to TorchScript with the export.py script and use it as follows:

import torch
model = torch.jit.load('yolov5m.torchscript.pt', map_location='cuda:0').eval()
pred = model(image)[0]

The output of this is of shape [1, 3, 80, 80, 85].

So the model definitely gets exported and I can use it for inference, but I have no clue why the output shape is different. I suspected layer fusion, so I tried to replicate the fusion like this:

model = torch.jit.load('yolov5m.torchscript.pt', map_location='cuda:0').float().fuse().eval()

but this fails with: 'RecursiveScriptModule' object has no attribute 'fuse'

I was wondering if there is something else one needs to do to make the two approaches consistent.
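For reference, the [1, 3, 80, 80, 85] shape is the per-anchor layout of a single detection scale: 3 anchors per grid cell, an 80x80 grid (the stride-8 scale of a 640x640 input), and 85 values per anchor, i.e. 4 box coordinates, 1 objectness score and 80 class scores for a COCO-trained model. The sketch below shows how a raw [1, 255, 80, 80] convolution output can be rearranged into that layout. As far as I can tell this mirrors the view/permute in YOLOv5's Detect layer, but na and nc here are assumed values for a COCO model, not read from a checkpoint.

```python
import torch

# Assumed values for a COCO-trained YOLOv5 model (not read from any checkpoint):
na = 3         # anchors per detection scale
nc = 80        # number of classes
no = nc + 5    # values per anchor: x, y, w, h, objectness, 80 class scores

# Raw output of the detection conv at the stride-8 scale for a 640x640 input:
# an 80x80 grid with na * no = 255 channels.
raw = torch.zeros(1, na * no, 80, 80)

bs, _, ny, nx = raw.shape
# Split the channel dimension into (anchor, value) and move the values last.
per_anchor = raw.view(bs, na, no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
print(per_anchor.shape)  # torch.Size([1, 3, 80, 80, 85])
```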

@dagap dagap added the bug Something isn't working label Nov 30, 2020

github-actions bot commented Nov 30, 2020

Hello @dagap, thank you for your interest in 🚀 YOLOv5! Please visit our ⭐️ Tutorials to get started, where you can find quickstart guides for simple tasks like Custom Data Training all the way to advanced concepts like Hyperparameter Evolution.

If this is a 🐛 Bug Report, please provide screenshots and minimum viable code to reproduce your issue, otherwise we can not help you.

If this is a custom training ❓ Question, please provide as much information as possible, including dataset images, training logs, screenshots, and a public link to online W&B logging if available.

For business inquiries or professional support requests please visit https://www.ultralytics.com or email Glenn Jocher at glenn.jocher@ultralytics.com.

Requirements

Python 3.8 or later with all requirements.txt dependencies installed, including torch>=1.7. To install run:

$ pip install -r requirements.txt

Environments

YOLOv5 may be run in any of the following up-to-date verified environments (with all dependencies including CUDA/CUDNN, Python and PyTorch preinstalled):

Status

CI CPU testing

If this badge is green, all YOLOv5 GitHub Actions Continuous Integration (CI) tests are currently passing. CI tests verify correct operation of YOLOv5 training (train.py), testing (test.py), inference (detect.py) and export (export.py) on MacOS, Windows, and Ubuntu every 24 hours and on every commit.

smartinellimarco commented Nov 30, 2020

Try changing
model.model[-1].export = True
to False in export.py.
That should add an extra output with the same shape as the model's output in eager mode.
You'll need to do NMS over the new one (borrow it from detect.py).
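To check which layout a given model actually returns, a recursive shape dump helps: with export=True the exported model returns only the list of three per-scale tensors, while with export=False it returns a tuple whose first element is the decoded tensor (as reported further down in this thread). A minimal sketch; `describe` is a hypothetical helper, and the dummy outputs only mimic the shapes reported here.

```python
import torch


def describe(out, indent=0):
    """Print the shapes inside a (possibly nested) model output.

    Hypothetical debugging helper: handles tensors, tuples and lists and
    returns a flat list of the tensor shapes found, in traversal order.
    """
    pad = "  " * indent
    shapes = []
    if isinstance(out, torch.Tensor):
        print(f"{pad}Tensor {tuple(out.shape)}")
        shapes.append(tuple(out.shape))
    elif isinstance(out, (tuple, list)):
        print(f"{pad}{type(out).__name__} of length {len(out)}")
        for item in out:
            shapes.extend(describe(item, indent + 1))
    else:
        print(f"{pad}{type(out).__name__}")
    return shapes


# Dummy outputs shaped like the two modes discussed in this thread.
raw = [torch.zeros(1, 3, s, s, 85) for s in (80, 40, 20)]  # export=True: raw list
eager = (torch.zeros(1, 25200, 85), raw)                    # export=False: (decoded, raw)
describe(raw)
describe(eager)
```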

I recommend reading these issues in full:


dagap commented Nov 30, 2020

@MarcoCBA Thank you for the reply. The code in detect.py does NMS after the call pred = model(image)[0], but I already see the difference before that, just from the forward pass through the exported model.

If I set it to False, I get this error:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

If I do everything on the CPU, the output for an input like image = torch.zeros(1, 3, 640, 640) is a tuple. The first element is a tensor of shape ([1, 25200, 85]).

The second element is a list of size 3 with tensors of shape:

torch.Size([1, 3, 80, 80, 85])
torch.Size([1, 3, 40, 40, 85])
torch.Size([1, 3, 20, 20, 85])

Do you reckon I need to replicate the detection layer?
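On replicating the detection layer: as far as I know, YOLOv5's Detect layer (models/yolo.py at the time of this issue) builds the decoded tensor from the three per-scale outputs by applying a sigmoid, decoding boxes against a cell grid and the anchors, and concatenating the scales; 25200 = 3 x (80x80 + 40x40 + 20x20). A minimal standalone sketch of that decoding, assuming the default YOLOv5 anchors and strides below and raw (pre-sigmoid) inputs; compare it against your own checkout of models/yolo.py before relying on it.

```python
import torch

# ASSUMPTION: default YOLOv5 anchors (in pixels) and strides for the three scales.
# Models trained with different (e.g. auto-computed) anchors need their own values.
ANCHORS = [
    [(10, 13), (16, 30), (33, 23)],       # stride 8  (80x80 grid at 640x640)
    [(30, 61), (62, 45), (59, 119)],      # stride 16 (40x40)
    [(116, 90), (156, 198), (373, 326)],  # stride 32 (20x20)
]
STRIDES = [8, 16, 32]


def make_grid(nx, ny):
    # (x, y) index of every grid cell, shaped [1, 1, ny, nx, 2] for broadcasting.
    yv = torch.arange(ny).view(ny, 1).expand(ny, nx)
    xv = torch.arange(nx).view(1, nx).expand(ny, nx)
    return torch.stack((xv, yv), 2).view(1, 1, ny, nx, 2).float()


def decode(raw_outputs):
    """Turn per-scale raw tensors [bs, na, ny, nx, no] into one [bs, N, no] tensor."""
    z = []
    for i, xi in enumerate(raw_outputs):
        bs, na, ny, nx, no = xi.shape
        grid = make_grid(nx, ny).to(xi.device)
        anchor_grid = torch.tensor(ANCHORS[i], dtype=xi.dtype, device=xi.device).view(1, na, 1, 1, 2)
        y = xi.sigmoid()
        # Box centre: predicted offset inside the cell plus the cell index, scaled to pixels.
        y[..., 0:2] = (y[..., 0:2] * 2.0 - 0.5 + grid) * STRIDES[i]
        # Box width/height: anchor size scaled by the prediction.
        y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * anchor_grid
        z.append(y.reshape(bs, -1, no))
    return torch.cat(z, 1)


raw = [torch.zeros(1, 3, s, s, 85) for s in (80, 40, 20)]
pred = decode(raw)
print(pred.shape)  # torch.Size([1, 25200, 85])
```

The resulting tensor can then go through the same NMS as the eager-mode output.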

smartinellimarco commented Dec 1, 2020


The problem is that some constants are loaded onto a specific device. You could either move everything to CUDA (but then the model will only work on CUDA), or modify the code inside the traced model (unzip it) to make it device-agnostic.
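A minimal sketch of what "constants loaded onto a specific device" means, using hypothetical toy modules rather than YOLOv5 itself: torch.jit.trace records tensors created inside forward() together with the device they were created on, so a model traced on the CPU can later mix CPU and CUDA tensors. As far as I know, even writing device=x.device does not help under tracing, because the device is frozen at trace time; torch.jit.script keeps it as a runtime expression. Tracing on the device you will run on is the other way around this.

```python
import torch


class TracedGrid(torch.nn.Module):
    # Hypothetical toy module: creates a tensor inside forward() without a device.
    def forward(self, x):
        return x + torch.arange(x.shape[-1])


class ScriptedGrid(torch.nn.Module):
    # Same computation, but the new tensor is created on the input's device.
    def forward(self, x):
        return x + torch.arange(x.shape[-1], device=x.device)


example = torch.zeros(2, 3)

# Tracing records the arange call with the device it ran on (CPU here) as a constant.
traced = torch.jit.trace(TracedGrid(), example)
print(traced.code)

# Scripting keeps device=x.device as a runtime expression instead of a constant.
scripted = torch.jit.script(ScriptedGrid())
print(scripted(example))
```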
I don't know why, but the person who posted how to do that deleted their comment.
Luckily I quoted it, so you can still read it:

The shapes are exactly what you should expect. I'm not sure what those three tensors even are, but if you do NMS over the first one, of shape [1, 25200, 85], you will get boxes, scores and labels.

Here is the code for NMS:

import time

import numpy as np
import torch
import torchvision


def non_max_suppression(prediction, conf_thres=0.5, iou_thres=0.6, classes=None, agnostic=False, labels=()):
    """
    Performs Non-Maximum Suppression (NMS) on inference results
    Returns:
         detections with shape: nx6 (x1, y1, x2, y2, conf, cls)
    """

    # Number of classes.
    nc = prediction[0].shape[1] - 5
    
    # Candidates.
    xc = prediction[..., 4] > conf_thres

    # Settings:
    # Minimum and maximum box width and height in pixels.
    min_wh, max_wh = 2, 4096

    # Maximum number of detections per image.
    max_det = 300
    
    # Timeout.
    time_limit = 10.0  
    
    # Require redundant detections.
    redundant = True
    
    # Multiple labels per box (adds 0.5ms/img).
    multi_label = nc > 1
    
    # Use Merge-NMS.
    merge = False

    t = time.time()
    output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]  # empty result per image, on the input's device
    for xi, x in enumerate(prediction):  # image index, image inference
        
        # Apply constraints:
        # Confidence.
        x = x[xc[xi]]

        # Cat apriori labels if autolabelling.
        if labels and len(labels[xi]):
            l = labels[xi]
            v = torch.zeros((len(l), nc + 5), device=x.device)
            v[:, :4] = l[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(l)), l[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image.
        if not x.shape[0]:
            continue

        # Compute conf.
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2).
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls).
        if multi_label:
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:

            # Best class only.
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class.
        if classes:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # If none remain process next image.
        # Number of boxes.
        n = x.shape[0]
        if not n:
            continue

        # Batched NMS:
        # Classes.
        c = x[:, 5:6] * (0 if agnostic else max_wh)
        
        # Boxes (offset by class), scores.
        boxes, scores = x[:, :4] + c, x[:, 4]
        
        # NMS.
        i = torchvision.ops.nms(boxes, scores, iou_thres)
        
        # Limit detections.
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):
            
            # Merge NMS (boxes merged using weighted mean).
            # Update boxes as boxes(i,4) = weights(i,n) * boxes(n,4).
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            break  # time limit exceeded

    return output


def xywh2xyxy(x):
    # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
    y = torch.zeros_like(x) if isinstance(x, torch.Tensor) else np.zeros_like(x)
    y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
    y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
    y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom right x
    y[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom right y
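    return y


def box_iou(box1, box2):
    # Pairwise IoU between two sets of (x1, y1, x2, y2) boxes; returns an [N, M] matrix (used by Merge-NMS)
    def box_area(box):
        return (box[:, 2] - box[:, 0]) * (box[:, 3] - box[:, 1])

    area1 = box_area(box1)
    area2 = box_area(box2)
    inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
    return inter / (area1[:, None] + area2 - inter)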

Then just run:

out = non_max_suppression(out, conf_thres=0.7)[0]  # out is the [1, 25200, 85] tensor; batch size is 1
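The line c = x[:, 5:6] * (0 if agnostic else max_wh) in the code above performs per-class NMS with a single NMS call: shifting each class's boxes by class_index * max_wh means boxes of different classes can never overlap. A minimal sketch of that trick on three hand-made boxes; to stay dependency-free it uses a plain greedy NMS written in PyTorch as a stand-in for torchvision.ops.nms (`greedy_nms` is a hypothetical helper, not part of YOLOv5).

```python
import torch


def box_area(b):
    # Area of (x1, y1, x2, y2) boxes.
    return (b[:, 2] - b[:, 0]).clamp(min=0) * (b[:, 3] - b[:, 1]).clamp(min=0)


def greedy_nms(boxes, scores, iou_thres):
    """Plain greedy NMS (hypothetical stand-in for torchvision.ops.nms)."""
    order = scores.argsort(descending=True)
    keep = []
    while order.numel() > 0:
        i = int(order[0])
        keep.append(i)
        rest = order[1:]
        # IoU between the highest-scoring remaining box and all the others.
        lt = torch.max(boxes[i, :2], boxes[rest, :2])
        rb = torch.min(boxes[i, 2:], boxes[rest, 2:])
        inter = (rb - lt).clamp(min=0).prod(1)
        iou = inter / (box_area(boxes[i:i + 1]) + box_area(boxes[rest]) - inter)
        # Drop boxes that overlap the kept box too much.
        order = rest[iou <= iou_thres]
    return torch.tensor(keep, dtype=torch.long)


max_wh = 4096  # same offset as in the NMS code above; larger than any box coordinate
boxes = torch.tensor([[0.0, 0.0, 100.0, 100.0],
                      [5.0, 5.0, 105.0, 105.0],    # overlaps box 0, same class
                      [5.0, 5.0, 105.0, 105.0]])   # same box as 1, but class 1
scores = torch.tensor([0.9, 0.8, 0.7])
cls = torch.tensor([0.0, 0.0, 1.0])

# Class-agnostic: one NMS over everything, so box 0 suppresses both others.
print(greedy_nms(boxes, scores, 0.6))
# Per-class: shift each class by cls * max_wh so different classes never overlap.
print(greedy_nms(boxes + cls[:, None] * max_wh, scores, 0.6))
```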


github-actions bot commented Jan 1, 2021

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

@github-actions github-actions bot added the Stale label Jan 1, 2021
@github-actions github-actions bot closed this as completed Jan 6, 2021
@sourabhyadav


@dagap I am also facing the same issue:

In normal inference, the output is a single tensor whose shape follows the batch size:

Input shape:
imgs size: torch.Size([2, 3, 384, 640])

Output shape:
dtype=torch.float16, shape: torch.Size([2, 15120, 85])

However, the TorchScript output is a list of length 3, whether the input batch size is 1 or 2.

Input Shape:
imgs size: torch.Size([1, 3, 384, 640])

Output Shape:
inf_out[0] : torch.Size([1, 3, 48, 80, 85])
inf_out[1] : torch.Size([1, 3, 24, 40, 85])
inf_out[2] : torch.Size([1, 3, 12, 20, 85])

My doubts are:
Although the total number of predictions matches, I am not sure how to pack these tensors properly for further processing such as NMS and box filtering.

The main doubt: in your case the first element was already the correct tensor, so why am I not getting the same?
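The list has one entry per detection scale (strides 8, 16 and 32), not one per image: the batch is dimension 0 of each tensor, which is why its length stays 3 for any batch size. For a 384x640 input the grids are 48x80, 24x40 and 12x20, and 3 x (3840 + 960 + 240) = 15120, matching the eager-mode [2, 15120, 85]. The decoded tensor seen earlier in this thread came from re-exporting with export=False, which returns a (decoded, raw) tuple. A minimal sketch of the packing step only (`flatten_scales` is a hypothetical helper); it does not apply the sigmoid or the grid/anchor box decoding, so, assuming the exported model returns raw values, they still need that decoding before NMS.

```python
import torch


def flatten_scales(raw_outputs):
    # [bs, na, ny, nx, no] per scale -> [bs, na*ny*nx, no], then join the scales.
    return torch.cat([t.reshape(t.shape[0], -1, t.shape[-1]) for t in raw_outputs], 1)


batch = 2
raw = [torch.zeros(batch, 3, ny, nx, 85) for ny, nx in ((48, 80), (24, 40), (12, 20))]
flat = flatten_scales(raw)
print(flat.shape)  # torch.Size([2, 15120, 85])
```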
