Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix bug to allow pointnet to train in fp16 #1207

Merged
merged 4 commits into from
Feb 16, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions mmdet3d/ops/gather_points/gather_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def forward(ctx, features: torch.Tensor,

B, npoint = indices.size()
_, C, N = features.size()
output = torch.cuda.FloatTensor(B, C, npoint)
output = features.new_zeros((B, C, npoint))

gather_points_ext.gather_points_wrapper(B, C, N, npoint, features,
indices, output)
Expand All @@ -41,7 +41,7 @@ def backward(ctx, grad_out):
idx, C, N = ctx.for_backwards
B, npoint = idx.size()

grad_features = torch.cuda.FloatTensor(B, C, N).zero_()
grad_features = grad_out.new_zeros((B, C, N))
grad_out_data = grad_out.data.contiguous()
gather_points_ext.gather_points_grad_wrapper(B, C, N, npoint,
grad_out_data, idx,
Expand Down
55 changes: 26 additions & 29 deletions mmdet3d/ops/gather_points/src/gather_points.cpp
Original file line number Diff line number Diff line change
@@ -1,57 +1,54 @@
#include <ATen/cuda/CUDAContext.h>
#include <ATen/TensorUtils.h>
#include <THC/THC.h>
#include <torch/extension.h>
#include <torch/serialize/tensor.h>

#include <vector>


extern THCState *state;

int gather_points_wrapper(int b, int c, int n, int npoints,
at::Tensor points_tensor, at::Tensor idx_tensor,
at::Tensor out_tensor);
at::Tensor& points_tensor, at::Tensor& idx_tensor,
at::Tensor& out_tensor);

void gather_points_kernel_launcher(int b, int c, int n, int npoints,
const float *points, const int *idx,
float *out, cudaStream_t stream);
const at::Tensor& points_tensor,
const at::Tensor& idx_tensor,
at::Tensor& out_tensor);

int gather_points_grad_wrapper(int b, int c, int n, int npoints,
at::Tensor grad_out_tensor,
at::Tensor idx_tensor,
at::Tensor grad_points_tensor);
at::Tensor& grad_out_tensor,
at::Tensor& idx_tensor,
at::Tensor& grad_points_tensor);

void gather_points_grad_kernel_launcher(int b, int c, int n, int npoints,
const float *grad_out, const int *idx,
float *grad_points,
cudaStream_t stream);
const at::Tensor& grad_out_tensor,
const at::Tensor& idx_tensor,
at::Tensor& grad_points_tensor);

int gather_points_wrapper(int b, int c, int n, int npoints,
at::Tensor points_tensor, at::Tensor idx_tensor,
at::Tensor out_tensor) {
const float *points = points_tensor.data_ptr<float>();
const int *idx = idx_tensor.data_ptr<int>();
float *out = out_tensor.data_ptr<float>();

cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
gather_points_kernel_launcher(b, c, n, npoints, points, idx, out, stream);
at::Tensor& points_tensor, at::Tensor& idx_tensor,
at::Tensor& out_tensor)
{
gather_points_kernel_launcher(b, c, n, npoints, points_tensor, idx_tensor, out_tensor);
return 1;
}

int gather_points_grad_wrapper(int b, int c, int n, int npoints,
at::Tensor grad_out_tensor,
at::Tensor idx_tensor,
at::Tensor grad_points_tensor) {
const float *grad_out = grad_out_tensor.data_ptr<float>();
const int *idx = idx_tensor.data_ptr<int>();
float *grad_points = grad_points_tensor.data_ptr<float>();

cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
gather_points_grad_kernel_launcher(b, c, n, npoints, grad_out, idx,
grad_points, stream);
at::Tensor& grad_out_tensor,
at::Tensor& idx_tensor,
at::Tensor& grad_points_tensor)
{
gather_points_grad_kernel_launcher(b, c, n, npoints, grad_out_tensor, idx_tensor,
grad_points_tensor);
return 1;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("gather_points_wrapper", &gather_points_wrapper,
"gather_points_wrapper");
m.def("gather_points_grad_wrapper", &gather_points_grad_wrapper,
Expand Down
70 changes: 50 additions & 20 deletions mmdet3d/ops/gather_points/src/gather_points_cuda.cu
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
#include <stdio.h>
#include <stdlib.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/types.h>

#include <ATen/cuda/CUDAApplyUtils.cuh>

#define TOTAL_THREADS 1024
#define THREADS_PER_BLOCK 256
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))

template <typename scalar_t>
__global__ void gather_points_kernel(int b, int c, int n, int m,
const float *__restrict__ points,
const scalar_t *__restrict__ points,
const int *__restrict__ idx,
float *__restrict__ out) {
scalar_t *__restrict__ out) {
// points: (B, C, N)
// idx: (B, M)
// output:
Expand All @@ -26,32 +33,44 @@ __global__ void gather_points_kernel(int b, int c, int n, int m,
}

void gather_points_kernel_launcher(int b, int c, int n, int npoints,
const float *points, const int *idx,
float *out, cudaStream_t stream) {
const at::Tensor& points_tensor,
const at::Tensor& idx_tensor,
at::Tensor& out_tensor)
{
// points: (B, C, N)
// idx: (B, npoints)
// output:
// out: (B, C, npoints)

cudaError_t err;
dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c,
b); // blockIdx.x(col), blockIdx.y(row)
b); // blockIdx.x(col), blockIdx.y(row)
dim3 threads(THREADS_PER_BLOCK);

gather_points_kernel<<<blocks, threads, 0, stream>>>(b, c, n, npoints, points,
idx, out);

cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();

AT_DISPATCH_FLOATING_TYPES_AND_HALF(
out_tensor.scalar_type(), "gather_points_kernel",
[&]
{
const scalar_t *points = points_tensor.data_ptr<scalar_t>();
const int *idx = idx_tensor.data_ptr<int>();
scalar_t *out = out_tensor.data_ptr<scalar_t>();
gather_points_kernel<<<blocks, threads, 0, stream>>>(b, c, n, npoints, points,
idx, out);
});
err = cudaGetLastError();
if (cudaSuccess != err) {
if (cudaSuccess != err)
{
fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
exit(-1);
}
}

template <typename scalar_t>
__global__ void gather_points_grad_kernel(int b, int c, int n, int m,
const float *__restrict__ grad_out,
const scalar_t *__restrict__ grad_out,
const int *__restrict__ idx,
float *__restrict__ grad_points) {
scalar_t *__restrict__ grad_points) {
// grad_out: (B, C, M)
// idx: (B, M)
// output:
Expand All @@ -70,25 +89,36 @@ __global__ void gather_points_grad_kernel(int b, int c, int n, int m,
}

void gather_points_grad_kernel_launcher(int b, int c, int n, int npoints,
const float *grad_out, const int *idx,
float *grad_points,
cudaStream_t stream) {
const at::Tensor& grad_out_tensor,
const at::Tensor& idx_tensor,
at::Tensor& grad_points_tensor)
{
// grad_out: (B, C, npoints)
// idx: (B, npoints)
// output:
// grad_points: (B, C, N)

cudaError_t err;
dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c,
b); // blockIdx.x(col), blockIdx.y(row)
b); // blockIdx.x(col), blockIdx.y(row)
dim3 threads(THREADS_PER_BLOCK);

gather_points_grad_kernel<<<blocks, threads, 0, stream>>>(
b, c, n, npoints, grad_out, idx, grad_points);
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad_points_tensor.scalar_type(), "gather_points_grad_kernel",
[&]
{
const scalar_t *grad_out = grad_out_tensor.data_ptr<scalar_t>();
const int *idx = idx_tensor.data_ptr<int>();
scalar_t *grad_points = grad_points_tensor.data_ptr<scalar_t>();
gather_points_grad_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
b, c, n, npoints, grad_out, idx, grad_points);
});

err = cudaGetLastError();
if (cudaSuccess != err) {
if (cudaSuccess != err)
{
fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
exit(-1);
}
}
}
6 changes: 5 additions & 1 deletion mmdet3d/ops/group_points/group_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from torch import nn as nn
from torch.autograd import Function
from typing import Tuple

from mmcv.runner import force_fp32
from ..ball_query import ball_query
from ..knn import knn
from . import group_points_ext
Expand Down Expand Up @@ -60,7 +60,9 @@ def __init__(self,
if self.max_radius is None:
assert not self.normalize_xyz, \
'can not normalize grouped xyz when max_radius is None'
self.fp16_enabled = False

@force_fp32()
def forward(self, points_xyz, center_xyz, features=None):
"""forward.

Expand Down Expand Up @@ -141,7 +143,9 @@ class GroupAll(nn.Module):
def __init__(self, use_xyz: bool = True):
super().__init__()
self.use_xyz = use_xyz
self.fp16_enabled = False

@force_fp32()
def forward(self,
xyz: torch.Tensor,
new_xyz: torch.Tensor,
Expand Down
2 changes: 2 additions & 0 deletions tests/test_models/test_common_modules/test_pointnet_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,8 @@ def test_gather_points():
[-0.7172, 0.0462, -0.6227, -0.7172, -0.7172, -0.7172]]]).cuda()

assert torch.allclose(output, expected_output)
output_half = gather_points(features.half(), idx)
assert torch.allclose(output_half, expected_output.half())


def test_three_interpolate():
Expand Down