Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhance] Remove CUDA part from diff_iou_rotated #2067

Open
wants to merge 6 commits into
base: master
Choose a base branch
from

Conversation

filaPro
Copy link
Contributor

@filaPro filaPro commented Jun 21, 2022

Motivation

Fix #1922. Following lilanxiao/Rotated_IoU#39#issuecomment-1146352088 we remove cuda and cpp part of diff_iou_rotated in favour of pure pytorch implementation. It has negligible affect on speed however the overall accuracy in corner cases is much better.

Modification

Replace sorting vertices from cuda to pytorch. Split test to GPU and CPU.

BC-breaking (Optional)

No, the public API is not changed. However we may want to move it from mmcv.ops as it does not require cuda now. mmdet3d does not use it in master. I've checked this PR with my FCAF3D PR mmdetection3d#1547. Also diff_iou_rotated is used in mmrotate master. Maybe @ZwwWayne or @Tai-Wang can have a look on mmdetection3d connection and @zytx121 for mmrotate.

Benchmark

Setup: ubuntu 18.04.6, nvidia driver 470.129.06, nvidia geforce rtx 3090, pytorch/pytorch:1.8.1-cuda10.2-cudnn7-devel

# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import pytest
import torch

from mmcv.ops import diff_iou_rotated_2d, diff_iou_rotated_3d

def test():
    np_boxes_2d_1 = np.random.random((100, 1000, 5)).astype(np.float32)
    np_boxes_2d_2 = np.random.random((100, 1000, 5)).astype(np.float32)
    boxes1 = torch.from_numpy(np_boxes_2d_1).cuda()
    boxes2 = torch.from_numpy(np_boxes_2d_2).cuda()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    ious = diff_iou_rotated_2d(boxes1, boxes2)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end)

a = [test() for _ in range(100)][1:]
print(f'mean={np.mean(a):.4f}, std={np.std(a):.4f}')

master >>> mean=7.1313, std=0.2482
this pr >>> mean=11.1537, std=0.3251

@zhouzaida
Copy link
Collaborator

Hi @filaPro , thanks for your contributions. Could you add some benchmark results in the PR description? You can refer to #1718 (comment).

@filaPro
Copy link
Contributor Author

filaPro commented Jun 21, 2022

Hi @zhouzaida , done. I haven't tried torch.cuda.Event before, hope my script is right. Also don't think that the decrease of speed is important here.

@zhouzaida zhouzaida requested review from grimoire and ZwwWayne June 21, 2022 14:06
Copy link
Member

@grimoire grimoire left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@Tai-Wang
Copy link
Member

Tai-Wang commented Jul 1, 2022

Only need to fix minor comments. Others LGTM.

@Tai-Wang
Copy link
Member

Tai-Wang commented Jul 1, 2022

@filaPro Great contribution! BTW, I am curious about the iou computation used in mmdet3d. Sometimes I feel like there are also some corner cases that BEV IoU can not be precisely computed, for example, for two strictly overlapped boxes on KITTI. Do you have any idea about it and is it related to similar problems as shown in this PR?

@filaPro
Copy link
Contributor Author

filaPro commented Jul 6, 2022

Hi, @Tai-Wang.
I think yes, it is the similar problem. The thing is that the intersection of 2 rotated boxes is a convex polygon with from 1 to 8 vertices. We first determine them and then simply (because it is convex) calculate the area. But when the edges of the boxes are collinear it is hard to determine the correct intersection coordinates. And for some reason this numerical instability appears more with cuda code. As the pytorch implementation is more stable we can consider removing mmcv.ops.box_iou_rotated in favour of this mmcv.ops.diff_iou_rotated in the future.

Copy link
Contributor

@zytx121 zytx121 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice work! LGTM.

@Tai-Wang
Copy link
Member

LGTM. Just need to resolve conflicts

@zhouzaida
Copy link
Collaborator

Hi @filaPro , this PR can be merged after resolving conflicts.

@zytx121
Copy link
Contributor

zytx121 commented Oct 25, 2022

Hi @filaPro
According to #2335, even if the two boxes overlap, there is still a gradient. Does this meet our expectations?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

wrong result of differentiable 3D IoU, occurring larger than 1
5 participants