Skip to content

Commit 01b5f7b

Browse files
bottlerfacebook-github-bot
authored andcommitted
heterogenous KNN
Summary: Interface and working implementation of ragged KNN. Benchmarks (which aren't ragged) haven't slowed. New benchmark shows that ragged is faster than non-ragged of the same shape. Reviewed By: jcjohnson Differential Revision: D20696507 fbshipit-source-id: 21b80f71343a3475c8d3ee0ce2680f92f0fae4de
1 parent 29b9c44 commit 01b5f7b

File tree

6 files changed

+332
-84
lines changed

6 files changed

+332
-84
lines changed

pytorch3d/csrc/knn/knn.cu

+80-35
Original file line numberDiff line numberDiff line change
@@ -8,29 +8,40 @@
88
#include "dispatch.cuh"
99
#include "mink.cuh"
1010

11+
// A chunk of work is blocksize-many points of P1.
12+
// The number of potential chunks to do is N*(1+(P1-1)/blocksize)
13+
// call (1+(P1-1)/blocksize) chunks_per_cloud
14+
// These chunks are divided among the gridSize-many blocks.
15+
// In block b, we work on chunks b, b+gridSize, b+2*gridSize etc .
16+
// In chunk i, we work on cloud i/chunks_per_cloud on points starting from
17+
// blocksize*(i%chunks_per_cloud).
18+
1119
template <typename scalar_t>
1220
__global__ void KNearestNeighborKernelV0(
1321
const scalar_t* __restrict__ points1,
1422
const scalar_t* __restrict__ points2,
23+
const int64_t* __restrict__ lengths1,
24+
const int64_t* __restrict__ lengths2,
1525
scalar_t* __restrict__ dists,
1626
int64_t* __restrict__ idxs,
1727
const size_t N,
1828
const size_t P1,
1929
const size_t P2,
2030
const size_t D,
2131
const size_t K) {
22-
// Stupid version: Make each thread handle one query point and loop over
23-
// all P2 target points. There are N * P1 input points to handle, so
24-
// do a trivial parallelization over threads.
2532
// Store both dists and indices for knn in global memory.
26-
const int tid = threadIdx.x + blockIdx.x * blockDim.x;
27-
const int num_threads = blockDim.x * gridDim.x;
28-
for (int np = tid; np < N * P1; np += num_threads) {
29-
int n = np / P1;
30-
int p1 = np % P1;
33+
const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
34+
const int64_t chunks_to_do = N * chunks_per_cloud;
35+
for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
36+
const int64_t n = chunk / chunks_per_cloud;
37+
const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
38+
int64_t p1 = start_point + threadIdx.x;
39+
if (p1 >= lengths1[n])
40+
continue;
3141
int offset = n * P1 * K + p1 * K;
42+
int64_t length2 = lengths2[n];
3243
MinK<scalar_t, int64_t> mink(dists + offset, idxs + offset, K);
33-
for (int p2 = 0; p2 < P2; ++p2) {
44+
for (int p2 = 0; p2 < length2; ++p2) {
3445
// Find the distance between points1[n, p1] and points[n, p2]
3546
scalar_t dist = 0;
3647
for (int d = 0; d < D; ++d) {
@@ -48,6 +59,8 @@ template <typename scalar_t, int64_t D>
4859
__global__ void KNearestNeighborKernelV1(
4960
const scalar_t* __restrict__ points1,
5061
const scalar_t* __restrict__ points2,
62+
const int64_t* __restrict__ lengths1,
63+
const int64_t* __restrict__ lengths2,
5164
scalar_t* __restrict__ dists,
5265
int64_t* __restrict__ idxs,
5366
const size_t N,
@@ -58,18 +71,22 @@ __global__ void KNearestNeighborKernelV1(
5871
// so we can cache the current point in a thread-local array. We still store
5972
// the current best K dists and indices in global memory, so this should work
6073
// for very large K and fairly large D.
61-
const int tid = threadIdx.x + blockIdx.x * blockDim.x;
62-
const int num_threads = blockDim.x * gridDim.x;
6374
scalar_t cur_point[D];
64-
for (int np = tid; np < N * P1; np += num_threads) {
65-
int n = np / P1;
66-
int p1 = np % P1;
75+
const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
76+
const int64_t chunks_to_do = N * chunks_per_cloud;
77+
for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
78+
const int64_t n = chunk / chunks_per_cloud;
79+
const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
80+
int64_t p1 = start_point + threadIdx.x;
81+
if (p1 >= lengths1[n])
82+
continue;
6783
for (int d = 0; d < D; ++d) {
6884
cur_point[d] = points1[n * P1 * D + p1 * D + d];
6985
}
7086
int offset = n * P1 * K + p1 * K;
87+
int64_t length2 = lengths2[n];
7188
MinK<scalar_t, int64_t> mink(dists + offset, idxs + offset, K);
72-
for (int p2 = 0; p2 < P2; ++p2) {
89+
for (int p2 = 0; p2 < length2; ++p2) {
7390
// Find the distance between cur_point and points[n, p2]
7491
scalar_t dist = 0;
7592
for (int d = 0; d < D; ++d) {
@@ -89,40 +106,48 @@ struct KNearestNeighborV1Functor {
89106
size_t threads,
90107
const scalar_t* __restrict__ points1,
91108
const scalar_t* __restrict__ points2,
109+
const int64_t* __restrict__ lengths1,
110+
const int64_t* __restrict__ lengths2,
92111
scalar_t* __restrict__ dists,
93112
int64_t* __restrict__ idxs,
94113
const size_t N,
95114
const size_t P1,
96115
const size_t P2,
97116
const size_t K) {
98-
KNearestNeighborKernelV1<scalar_t, D>
99-
<<<blocks, threads>>>(points1, points2, dists, idxs, N, P1, P2, K);
117+
KNearestNeighborKernelV1<scalar_t, D><<<blocks, threads>>>(
118+
points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, K);
100119
}
101120
};
102121

103122
template <typename scalar_t, int64_t D, int64_t K>
104123
__global__ void KNearestNeighborKernelV2(
105124
const scalar_t* __restrict__ points1,
106125
const scalar_t* __restrict__ points2,
126+
const int64_t* __restrict__ lengths1,
127+
const int64_t* __restrict__ lengths2,
107128
scalar_t* __restrict__ dists,
108129
int64_t* __restrict__ idxs,
109130
const int64_t N,
110131
const int64_t P1,
111132
const int64_t P2) {
112133
// Same general implementation as V2, but also hoist K into a template arg.
113-
const int tid = threadIdx.x + blockIdx.x * blockDim.x;
114-
const int num_threads = blockDim.x * gridDim.x;
115134
scalar_t cur_point[D];
116135
scalar_t min_dists[K];
117136
int min_idxs[K];
118-
for (int np = tid; np < N * P1; np += num_threads) {
119-
int n = np / P1;
120-
int p1 = np % P1;
137+
const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
138+
const int64_t chunks_to_do = N * chunks_per_cloud;
139+
for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
140+
const int64_t n = chunk / chunks_per_cloud;
141+
const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
142+
int64_t p1 = start_point + threadIdx.x;
143+
if (p1 >= lengths1[n])
144+
continue;
121145
for (int d = 0; d < D; ++d) {
122146
cur_point[d] = points1[n * P1 * D + p1 * D + d];
123147
}
148+
int64_t length2 = lengths2[n];
124149
MinK<scalar_t, int> mink(min_dists, min_idxs, K);
125-
for (int p2 = 0; p2 < P2; ++p2) {
150+
for (int p2 = 0; p2 < length2; ++p2) {
126151
scalar_t dist = 0;
127152
for (int d = 0; d < D; ++d) {
128153
int offset = n * P2 * D + p2 * D + d;
@@ -146,20 +171,24 @@ struct KNearestNeighborKernelV2Functor {
146171
size_t threads,
147172
const scalar_t* __restrict__ points1,
148173
const scalar_t* __restrict__ points2,
174+
const int64_t* __restrict__ lengths1,
175+
const int64_t* __restrict__ lengths2,
149176
scalar_t* __restrict__ dists,
150177
int64_t* __restrict__ idxs,
151178
const int64_t N,
152179
const int64_t P1,
153180
const int64_t P2) {
154-
KNearestNeighborKernelV2<scalar_t, D, K>
155-
<<<blocks, threads>>>(points1, points2, dists, idxs, N, P1, P2);
181+
KNearestNeighborKernelV2<scalar_t, D, K><<<blocks, threads>>>(
182+
points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2);
156183
}
157184
};
158185

159186
template <typename scalar_t, int D, int K>
160187
__global__ void KNearestNeighborKernelV3(
161188
const scalar_t* __restrict__ points1,
162189
const scalar_t* __restrict__ points2,
190+
const int64_t* __restrict__ lengths1,
191+
const int64_t* __restrict__ lengths2,
163192
scalar_t* __restrict__ dists,
164193
int64_t* __restrict__ idxs,
165194
const size_t N,
@@ -169,19 +198,23 @@ __global__ void KNearestNeighborKernelV3(
169198
// Enabling sorting for this version leads to huge slowdowns; I suspect
170199
// that it forces min_dists into local memory rather than registers.
171200
// As a result this version is always unsorted.
172-
const int tid = threadIdx.x + blockIdx.x * blockDim.x;
173-
const int num_threads = blockDim.x * gridDim.x;
174201
scalar_t cur_point[D];
175202
scalar_t min_dists[K];
176203
int min_idxs[K];
177-
for (int np = tid; np < N * P1; np += num_threads) {
178-
int n = np / P1;
179-
int p1 = np % P1;
204+
const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
205+
const int64_t chunks_to_do = N * chunks_per_cloud;
206+
for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
207+
const int64_t n = chunk / chunks_per_cloud;
208+
const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
209+
int64_t p1 = start_point + threadIdx.x;
210+
if (p1 >= lengths1[n])
211+
continue;
180212
for (int d = 0; d < D; ++d) {
181213
cur_point[d] = points1[n * P1 * D + p1 * D + d];
182214
}
215+
int64_t length2 = lengths2[n];
183216
RegisterMinK<scalar_t, int, K> mink(min_dists, min_idxs);
184-
for (int p2 = 0; p2 < P2; ++p2) {
217+
for (int p2 = 0; p2 < length2; ++p2) {
185218
scalar_t dist = 0;
186219
for (int d = 0; d < D; ++d) {
187220
int offset = n * P2 * D + p2 * D + d;
@@ -205,13 +238,15 @@ struct KNearestNeighborKernelV3Functor {
205238
size_t threads,
206239
const scalar_t* __restrict__ points1,
207240
const scalar_t* __restrict__ points2,
241+
const int64_t* __restrict__ lengths1,
242+
const int64_t* __restrict__ lengths2,
208243
scalar_t* __restrict__ dists,
209244
int64_t* __restrict__ idxs,
210245
const size_t N,
211246
const size_t P1,
212247
const size_t P2) {
213-
KNearestNeighborKernelV3<scalar_t, D, K>
214-
<<<blocks, threads>>>(points1, points2, dists, idxs, N, P1, P2);
248+
KNearestNeighborKernelV3<scalar_t, D, K><<<blocks, threads>>>(
249+
points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2);
215250
}
216251
};
217252

@@ -257,6 +292,8 @@ int ChooseVersion(const int64_t D, const int64_t K) {
257292
std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
258293
const at::Tensor& p1,
259294
const at::Tensor& p2,
295+
const at::Tensor& lengths1,
296+
const at::Tensor& lengths2,
260297
int K,
261298
int version) {
262299
const auto N = p1.size(0);
@@ -267,8 +304,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
267304

268305
AT_ASSERTM(p2.size(2) == D, "Point sets must have the same last dimension");
269306
auto long_dtype = p1.options().dtype(at::kLong);
270-
auto idxs = at::full({N, P1, K}, -1, long_dtype);
271-
auto dists = at::full({N, P1, K}, -1, p1.options());
307+
auto idxs = at::zeros({N, P1, K}, long_dtype);
308+
auto dists = at::zeros({N, P1, K}, p1.options());
272309

273310
if (version < 0) {
274311
version = ChooseVersion(D, K);
@@ -294,6 +331,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
294331
<<<blocks, threads>>>(
295332
p1.data_ptr<scalar_t>(),
296333
p2.data_ptr<scalar_t>(),
334+
lengths1.data_ptr<int64_t>(),
335+
lengths2.data_ptr<int64_t>(),
297336
dists.data_ptr<scalar_t>(),
298337
idxs.data_ptr<int64_t>(),
299338
N,
@@ -314,6 +353,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
314353
threads,
315354
p1.data_ptr<scalar_t>(),
316355
p2.data_ptr<scalar_t>(),
356+
lengths1.data_ptr<int64_t>(),
357+
lengths2.data_ptr<int64_t>(),
317358
dists.data_ptr<scalar_t>(),
318359
idxs.data_ptr<int64_t>(),
319360
N,
@@ -336,6 +377,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
336377
threads,
337378
p1.data_ptr<scalar_t>(),
338379
p2.data_ptr<scalar_t>(),
380+
lengths1.data_ptr<int64_t>(),
381+
lengths2.data_ptr<int64_t>(),
339382
dists.data_ptr<scalar_t>(),
340383
idxs.data_ptr<int64_t>(),
341384
N,
@@ -357,6 +400,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
357400
threads,
358401
p1.data_ptr<scalar_t>(),
359402
p2.data_ptr<scalar_t>(),
403+
lengths1.data_ptr<int64_t>(),
404+
lengths2.data_ptr<int64_t>(),
360405
dists.data_ptr<scalar_t>(),
361406
idxs.data_ptr<int64_t>(),
362407
N,

pytorch3d/csrc/knn/knn.h

+22-7
Original file line numberDiff line numberDiff line change
@@ -13,42 +13,57 @@
1313
// containing P1 points of dimension D.
1414
// p2: FloatTensor of shape (N, P2, D) giving a batch of pointclouds each
1515
// containing P2 points of dimension D.
16+
// lengths1: LongTensor, shape (N,), giving actual length of each P1 cloud.
17+
// lengths2: LongTensor, shape (N,), giving actual length of each P2 cloud.
1618
// K: int giving the number of nearest points to return.
1719
// sorted: bool telling whether to sort the K returned points by their
1820
// distance.
1921
// version: Integer telling which implementation to use.
20-
// TODO(jcjohns): Document this more, or maybe remove it before landing.
2122
//
2223
// Returns:
2324
// p1_neighbor_idx: LongTensor of shape (N, P1, K), where
24-
// p1_neighbor_idx[n, i, k] = j means that the kth nearest
25-
// neighbor to p1[n, i] in the cloud p2[n] is p2[n, j].
25+
// p1_neighbor_idx[n, i, k] = j means that the kth nearest
26+
// neighbor to p1[n, i] in the cloud p2[n] is p2[n, j].
27+
// It is padded with zeros so that it can be used easily in a later
28+
// gather() operation.
29+
//
30+
// p1_neighbor_dists: FloatTensor of shape (N, P1, K) containing the squared
31+
// distance from each point p1[n, p, :] to its K neighbors
32+
// p2[n, p1_neighbor_idx[n, p, k], :].
2633

2734
// CPU implementation.
28-
std::tuple<at::Tensor, at::Tensor>
29-
KNearestNeighborIdxCpu(const at::Tensor& p1, const at::Tensor& p2, int K);
35+
std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCpu(
36+
const at::Tensor& p1,
37+
const at::Tensor& p2,
38+
const at::Tensor& lengths1,
39+
const at::Tensor& lengths2,
40+
int K);
3041

3142
// CUDA implementation
3243
std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
3344
const at::Tensor& p1,
3445
const at::Tensor& p2,
46+
const at::Tensor& lengths1,
47+
const at::Tensor& lengths2,
3548
int K,
3649
int version);
3750

3851
// Implementation which is exposed.
3952
std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdx(
4053
const at::Tensor& p1,
4154
const at::Tensor& p2,
55+
const at::Tensor& lengths1,
56+
const at::Tensor& lengths2,
4257
int K,
4358
int version) {
4459
if (p1.is_cuda() || p2.is_cuda()) {
4560
#ifdef WITH_CUDA
4661
CHECK_CONTIGUOUS_CUDA(p1);
4762
CHECK_CONTIGUOUS_CUDA(p2);
48-
return KNearestNeighborIdxCuda(p1, p2, K, version);
63+
return KNearestNeighborIdxCuda(p1, p2, lengths1, lengths2, K, version);
4964
#else
5065
AT_ERROR("Not compiled with GPU support.");
5166
#endif
5267
}
53-
return KNearestNeighborIdxCpu(p1, p2, K);
68+
return KNearestNeighborIdxCpu(p1, p2, lengths1, lengths2, K);
5469
}

pytorch3d/csrc/knn/knn_cpu.cpp

+13-5
Original file line numberDiff line numberDiff line change
@@ -4,27 +4,35 @@
44
#include <queue>
55
#include <tuple>
66

7-
std::tuple<at::Tensor, at::Tensor>
8-
KNearestNeighborIdxCpu(const at::Tensor& p1, const at::Tensor& p2, int K) {
7+
std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCpu(
8+
const at::Tensor& p1,
9+
const at::Tensor& p2,
10+
const at::Tensor& lengths1,
11+
const at::Tensor& lengths2,
12+
int K) {
913
const int N = p1.size(0);
1014
const int P1 = p1.size(1);
1115
const int D = p1.size(2);
1216
const int P2 = p2.size(1);
1317

1418
auto long_opts = p1.options().dtype(torch::kInt64);
15-
torch::Tensor idxs = torch::full({N, P1, K}, -1, long_opts);
19+
torch::Tensor idxs = torch::full({N, P1, K}, 0, long_opts);
1620
torch::Tensor dists = torch::full({N, P1, K}, 0, p1.options());
1721

1822
auto p1_a = p1.accessor<float, 3>();
1923
auto p2_a = p2.accessor<float, 3>();
24+
auto lengths1_a = lengths1.accessor<int64_t, 1>();
25+
auto lengths2_a = lengths2.accessor<int64_t, 1>();
2026
auto idxs_a = idxs.accessor<int64_t, 3>();
2127
auto dists_a = dists.accessor<float, 3>();
2228

2329
for (int n = 0; n < N; ++n) {
24-
for (int i1 = 0; i1 < P1; ++i1) {
30+
const int64_t length1 = lengths1_a[n];
31+
const int64_t length2 = lengths2_a[n];
32+
for (int64_t i1 = 0; i1 < length1; ++i1) {
2533
// Use a priority queue to store (distance, index) tuples.
2634
std::priority_queue<std::tuple<float, int>> q;
27-
for (int i2 = 0; i2 < P2; ++i2) {
35+
for (int64_t i2 = 0; i2 < length2; ++i2) {
2836
float dist = 0;
2937
for (int d = 0; d < D; ++d) {
3038
float diff = p1_a[n][i1][d] - p2_a[n][i2][d];

0 commit comments

Comments
 (0)