Skip to content

Commit

Permalink
add L1 support for KNN & Chamfer
Browse files Browse the repository at this point in the history
Summary:
Added L1 norm for KNN and chamfer op
* The norm is now specified with a variable `norm` which can only be 1 or 2

Reviewed By: bottler

Differential Revision: D35419637

fbshipit-source-id: 77813fec650b30c28342af90d5ed02c89133e136
  • Loading branch information
gkioxari authored and facebook-github-bot committed Apr 10, 2022
1 parent 4b94649 commit 67fff95
Show file tree
Hide file tree
Showing 8 changed files with 266 additions and 130 deletions.
76 changes: 53 additions & 23 deletions pytorch3d/csrc/knn/knn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ __global__ void KNearestNeighborKernelV0(
const size_t P1,
const size_t P2,
const size_t D,
const size_t K) {
const size_t K,
const size_t norm) {
// Store both dists and indices for knn in global memory.
const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
const int64_t chunks_to_do = N * chunks_per_cloud;
Expand All @@ -56,7 +57,8 @@ __global__ void KNearestNeighborKernelV0(
scalar_t coord1 = points1[n * P1 * D + p1 * D + d];
scalar_t coord2 = points2[n * P2 * D + p2 * D + d];
scalar_t diff = coord1 - coord2;
dist += diff * diff;
scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
dist += norm_diff;
}
mink.add(dist, p2);
}
Expand All @@ -74,7 +76,8 @@ __global__ void KNearestNeighborKernelV1(
const size_t N,
const size_t P1,
const size_t P2,
const size_t K) {
const size_t K,
const size_t norm) {
// Same idea as the previous version, but hoist D into a template argument
// so we can cache the current point in a thread-local array. We still store
// the current best K dists and indices in global memory, so this should work
Expand All @@ -99,7 +102,8 @@ __global__ void KNearestNeighborKernelV1(
scalar_t dist = 0;
for (int d = 0; d < D; ++d) {
scalar_t diff = cur_point[d] - points2[n * P2 * D + p2 * D + d];
dist += diff * diff;
scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
dist += norm_diff;
}
mink.add(dist, p2);
}
Expand All @@ -121,10 +125,11 @@ struct KNearestNeighborV1Functor {
const size_t N,
const size_t P1,
const size_t P2,
const size_t K) {
const size_t K,
const size_t norm) {
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
KNearestNeighborKernelV1<scalar_t, D><<<blocks, threads, 0, stream>>>(
points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, K);
points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, K, norm);
}
};

Expand All @@ -138,7 +143,8 @@ __global__ void KNearestNeighborKernelV2(
int64_t* __restrict__ idxs,
const int64_t N,
const int64_t P1,
const int64_t P2) {
const int64_t P2,
const size_t norm) {
// Same general implementation as V2, but also hoist K into a template arg.
scalar_t cur_point[D];
scalar_t min_dists[K];
Expand All @@ -161,7 +167,8 @@ __global__ void KNearestNeighborKernelV2(
for (int d = 0; d < D; ++d) {
int offset = n * P2 * D + p2 * D + d;
scalar_t diff = cur_point[d] - points2[offset];
dist += diff * diff;
scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
dist += norm_diff;
}
mink.add(dist, p2);
}
Expand All @@ -186,10 +193,11 @@ struct KNearestNeighborKernelV2Functor {
int64_t* __restrict__ idxs,
const int64_t N,
const int64_t P1,
const int64_t P2) {
const int64_t P2,
const size_t norm) {
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
KNearestNeighborKernelV2<scalar_t, D, K><<<blocks, threads, 0, stream>>>(
points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2);
points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, norm);
}
};

Expand All @@ -203,7 +211,8 @@ __global__ void KNearestNeighborKernelV3(
int64_t* __restrict__ idxs,
const size_t N,
const size_t P1,
const size_t P2) {
const size_t P2,
const size_t norm) {
// Same idea as V2, but use register indexing for thread-local arrays.
// Enabling sorting for this version leads to huge slowdowns; I suspect
// that it forces min_dists into local memory rather than registers.
Expand All @@ -229,7 +238,8 @@ __global__ void KNearestNeighborKernelV3(
for (int d = 0; d < D; ++d) {
int offset = n * P2 * D + p2 * D + d;
scalar_t diff = cur_point[d] - points2[offset];
dist += diff * diff;
scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
dist += norm_diff;
}
mink.add(dist, p2);
}
Expand All @@ -254,10 +264,11 @@ struct KNearestNeighborKernelV3Functor {
int64_t* __restrict__ idxs,
const size_t N,
const size_t P1,
const size_t P2) {
const size_t P2,
const size_t norm) {
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
KNearestNeighborKernelV3<scalar_t, D, K><<<blocks, threads, 0, stream>>>(
points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2);
points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, norm);
}
};

Expand Down Expand Up @@ -305,7 +316,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
const at::Tensor& p2,
const at::Tensor& lengths1,
const at::Tensor& lengths2,
int K,
const int norm,
const int K,
int version) {
// Check inputs are on the same device
at::TensorArg p1_t{p1, "p1", 1}, p2_t{p2, "p2", 2},
Expand All @@ -324,6 +336,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
const auto D = p2.size(2);
const int64_t K_64 = K;

TORCH_CHECK((norm == 1) || (norm == 2), "Norm must be 1 or 2.");

TORCH_CHECK(p2.size(2) == D, "Point sets must have the same last dimension");
auto long_dtype = lengths1.options().dtype(at::kLong);
auto idxs = at::zeros({N, P1, K}, long_dtype);
Expand Down Expand Up @@ -366,7 +380,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
P1,
P2,
D,
K);
K,
norm);
}));
} else if (version == 1) {
AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
Expand All @@ -387,7 +402,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
N,
P1,
P2,
K);
K,
norm);
}));
} else if (version == 2) {
AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
Expand All @@ -410,7 +426,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
idxs.data_ptr<int64_t>(),
N,
P1,
P2);
P2,
norm);
}));
} else if (version == 3) {
AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
Expand All @@ -433,7 +450,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
idxs.data_ptr<int64_t>(),
N,
P1,
P2);
P2,
norm);
}));
}
AT_CUDA_CHECK(cudaGetLastError());
Expand All @@ -459,7 +477,8 @@ __global__ void KNearestNeighborBackwardKernel(
const size_t P1,
const size_t P2,
const size_t K,
const size_t D) {
const size_t D,
const size_t norm) {
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = gridDim.x * blockDim.x;

Expand All @@ -481,8 +500,17 @@ __global__ void KNearestNeighborBackwardKernel(
if (p2_idx == -1) {
continue;
}
const float diff = 2.0 * grad_dist *
(p1[n * P1 * D + p1_idx * D + d] - p2[n * P2 * D + p2_idx * D + d]);
float diff = 0.0;
if (norm == 1) {
float sign =
(p1[n * P1 * D + p1_idx * D + d] > p2[n * P2 * D + p2_idx * D + d])
? 1.0
: -1.0;
diff = grad_dist * sign;
} else { // norm is 2
diff = 2.0 * grad_dist *
(p1[n * P1 * D + p1_idx * D + d] - p2[n * P2 * D + p2_idx * D + d]);
}
atomicAdd(grad_p1 + n * P1 * D + p1_idx * D + d, diff);
atomicAdd(grad_p2 + n * P2 * D + p2_idx * D + d, -1.0f * diff);
}
Expand All @@ -495,6 +523,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCuda(
const at::Tensor& lengths1,
const at::Tensor& lengths2,
const at::Tensor& idxs,
int norm,
const at::Tensor& grad_dists) {
// Check inputs are on the same device
at::TensorArg p1_t{p1, "p1", 1}, p2_t{p2, "p2", 2},
Expand Down Expand Up @@ -547,7 +576,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCuda(
P1,
P2,
K,
D);
D,
norm);

AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(grad_p1, grad_p2);
Expand Down
27 changes: 18 additions & 9 deletions pytorch3d/csrc/knn/knn.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
// containing P2 points of dimension D.
// lengths1: LongTensor, shape (N,), giving actual length of each P1 cloud.
// lengths2: LongTensor, shape (N,), giving actual length of each P2 cloud.
// norm: int specifying the norm for the distance (1 for L1, 2 for L2)
// K: int giving the number of nearest points to return.
// version: Integer telling which implementation to use.
//
Expand All @@ -41,35 +42,39 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCpu(
const at::Tensor& p2,
const at::Tensor& lengths1,
const at::Tensor& lengths2,
int K);
const int norm,
const int K);

// CUDA implementation
std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
const at::Tensor& p1,
const at::Tensor& p2,
const at::Tensor& lengths1,
const at::Tensor& lengths2,
int K,
int version);
const int norm,
const int K,
const int version);

// Implementation which is exposed.
std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdx(
const at::Tensor& p1,
const at::Tensor& p2,
const at::Tensor& lengths1,
const at::Tensor& lengths2,
int K,
int version) {
const int norm,
const int K,
const int version) {
if (p1.is_cuda() || p2.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CUDA(p1);
CHECK_CUDA(p2);
return KNearestNeighborIdxCuda(p1, p2, lengths1, lengths2, K, version);
return KNearestNeighborIdxCuda(
p1, p2, lengths1, lengths2, norm, K, version);
#else
AT_ERROR("Not compiled with GPU support.");
#endif
}
return KNearestNeighborIdxCpu(p1, p2, lengths1, lengths2, K);
return KNearestNeighborIdxCpu(p1, p2, lengths1, lengths2, norm, K);
}

// Compute gradients with respect to p1 and p2
Expand All @@ -86,6 +91,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdx(
// neighbor to p1[n, i] in the cloud p2[n] is p2[n, j].
// It is padded with zeros so that it can be used easily in a later
// gather() operation. This is computed from the forward pass.
// norm: int specifying the norm for the distance (1 for L1, 2 for L2)
// grad_dists: FLoatTensor of shape (N, P1, K) which contains the input
// gradients.
//
Expand All @@ -102,6 +108,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCpu(
const at::Tensor& lengths1,
const at::Tensor& lengths2,
const at::Tensor& idxs,
const int norm,
const at::Tensor& grad_dists);

// CUDA implementation
Expand All @@ -111,6 +118,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCuda(
const at::Tensor& lengths1,
const at::Tensor& lengths2,
const at::Tensor& idxs,
const int norm,
const at::Tensor& grad_dists);

// Implementation which is exposed.
Expand All @@ -120,19 +128,20 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackward(
const at::Tensor& lengths1,
const at::Tensor& lengths2,
const at::Tensor& idxs,
const int norm,
const at::Tensor& grad_dists) {
if (p1.is_cuda() || p2.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CUDA(p1);
CHECK_CUDA(p2);
return KNearestNeighborBackwardCuda(
p1, p2, lengths1, lengths2, idxs, grad_dists);
p1, p2, lengths1, lengths2, idxs, norm, grad_dists);
#else
AT_ERROR("Not compiled with GPU support.");
#endif
}
return KNearestNeighborBackwardCpu(
p1, p2, lengths1, lengths2, idxs, grad_dists);
p1, p2, lengths1, lengths2, idxs, norm, grad_dists);
}

// Utility to check whether a KNN version can be used.
Expand Down
20 changes: 16 additions & 4 deletions pytorch3d/csrc/knn/knn_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCpu(
const at::Tensor& p2,
const at::Tensor& lengths1,
const at::Tensor& lengths2,
int K) {
const int norm,
const int K) {
const int N = p1.size(0);
const int P1 = p1.size(1);
const int D = p1.size(2);
Expand All @@ -41,7 +42,11 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCpu(
float dist = 0;
for (int d = 0; d < D; ++d) {
float diff = p1_a[n][i1][d] - p2_a[n][i2][d];
dist += diff * diff;
if (norm == 1) {
dist += abs(diff);
} else { // norm is 2 (default)
dist += diff * diff;
}
}
int size = static_cast<int>(q.size());
if (size < K || dist < std::get<0>(q.top())) {
Expand Down Expand Up @@ -73,6 +78,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCpu(
const at::Tensor& lengths1,
const at::Tensor& lengths2,
const at::Tensor& idxs,
const int norm,
const at::Tensor& grad_dists) {
const int N = p1.size(0);
const int P1 = p1.size(1);
Expand Down Expand Up @@ -104,8 +110,14 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCpu(
continue;
}
for (int64_t d = 0; d < D; ++d) {
const float diff =
2.0f * grad_dists_a[n][i1][k] * (p1_a[n][i1][d] - p2_a[n][i2][d]);
float diff = 0.0;
if (norm == 1) {
float sign = (p1_a[n][i1][d] > p2_a[n][i2][d]) ? 1.0 : -1.0;
diff = grad_dists_a[n][i1][k] * sign;
} else { // norm is 2 (default)
diff = 2.0f * grad_dists_a[n][i1][k] *
(p1_a[n][i1][d] - p2_a[n][i2][d]);
}
grad_p1_a[n][i1][d] += diff;
grad_p2_a[n][i2][d] += -1.0f * diff;
}
Expand Down
Loading

0 comments on commit 67fff95

Please sign in to comment.