#include "dispatch.cuh"
#include "mink.cuh"

+ // A chunk of work is blocksize-many points of P1.
+ // The number of potential chunks to do is N*(1+(P1-1)/blocksize)
+ // call (1+(P1-1)/blocksize) chunks_per_cloud
+ // These chunks are divided among the gridSize-many blocks.
+ // In block b, we work on chunks b, b+gridSize, b+2*gridSize etc.
+ // In chunk i, we work on cloud i/chunks_per_cloud on points starting from
+ // blocksize*(i%chunks_per_cloud).
+
template <typename scalar_t>
__global__ void KNearestNeighborKernelV0(
    const scalar_t* __restrict__ points1,
    const scalar_t* __restrict__ points2,
+     const int64_t* __restrict__ lengths1,
+     const int64_t* __restrict__ lengths2,
    scalar_t* __restrict__ dists,
    int64_t* __restrict__ idxs,
    const size_t N,
    const size_t P1,
    const size_t P2,
    const size_t D,
    const size_t K) {
-   // Stupid version: Make each thread handle one query point and loop over
-   // all P2 target points. There are N * P1 input points to handle, so
-   // do a trivial parallelization over threads.
  // Store both dists and indices for knn in global memory.
-   const int tid = threadIdx.x + blockIdx.x * blockDim.x;
-   const int num_threads = blockDim.x * gridDim.x;
-   for (int np = tid; np < N * P1; np += num_threads) {
-     int n = np / P1;
-     int p1 = np % P1;
+   const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
+   const int64_t chunks_to_do = N * chunks_per_cloud;
+   for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
+     const int64_t n = chunk / chunks_per_cloud;
+     const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
+     int64_t p1 = start_point + threadIdx.x;
+     if (p1 >= lengths1[n])
+       continue;
    int offset = n * P1 * K + p1 * K;
+     int64_t length2 = lengths2[n];
    MinK<scalar_t, int64_t> mink(dists + offset, idxs + offset, K);
-     for (int p2 = 0; p2 < P2; ++p2) {
+     for (int p2 = 0; p2 < length2; ++p2) {
      // Find the distance between points1[n, p1] and points2[n, p2]
      scalar_t dist = 0;
      for (int d = 0; d < D; ++d) {
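The chunking comment added at the top of this hunk is the heart of the change: instead of a flat thread-per-point grid-stride loop, each block now strides over chunks, each chunk names one cloud plus a block-sized window of its points, and threads whose p1 falls at or beyond lengths1[n] simply skip the iteration. A minimal host-side sketch of the same index arithmetic, with made-up sizes chosen only to illustrate the mapping:

```cuda
// Illustration only: reproduce the chunk -> (cloud, start_point) mapping on
// the host. blocksize stands in for blockDim.x; N and P1 are made-up values.
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t N = 2, P1 = 1000, blocksize = 256;
  const int64_t chunks_per_cloud = 1 + (P1 - 1) / blocksize;  // ceil(P1/256) = 4
  const int64_t chunks_to_do = N * chunks_per_cloud;          // 8
  for (int64_t chunk = 0; chunk < chunks_to_do; ++chunk) {
    const int64_t n = chunk / chunks_per_cloud;                    // cloud index
    const int64_t start = blocksize * (chunk % chunks_per_cloud);  // first p1 in chunk
    const int64_t stop = std::min(start + blocksize, P1);
    std::printf("chunk %lld -> cloud %lld, p1 in [%lld, %lld)\n",
                (long long)chunk, (long long)n, (long long)start, (long long)stop);
  }
  return 0;
}
```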
@@ -48,6 +59,8 @@ template <typename scalar_t, int64_t D>
__global__ void KNearestNeighborKernelV1(
    const scalar_t* __restrict__ points1,
    const scalar_t* __restrict__ points2,
+     const int64_t* __restrict__ lengths1,
+     const int64_t* __restrict__ lengths2,
    scalar_t* __restrict__ dists,
    int64_t* __restrict__ idxs,
    const size_t N,
@@ -58,18 +71,22 @@ __global__ void KNearestNeighborKernelV1(
  // so we can cache the current point in a thread-local array. We still store
  // the current best K dists and indices in global memory, so this should work
  // for very large K and fairly large D.
-   const int tid = threadIdx.x + blockIdx.x * blockDim.x;
-   const int num_threads = blockDim.x * gridDim.x;
  scalar_t cur_point[D];
-   for (int np = tid; np < N * P1; np += num_threads) {
-     int n = np / P1;
-     int p1 = np % P1;
+   const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
+   const int64_t chunks_to_do = N * chunks_per_cloud;
+   for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
+     const int64_t n = chunk / chunks_per_cloud;
+     const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
+     int64_t p1 = start_point + threadIdx.x;
+     if (p1 >= lengths1[n])
+       continue;
    for (int d = 0; d < D; ++d) {
      cur_point[d] = points1[n * P1 * D + p1 * D + d];
    }
    int offset = n * P1 * K + p1 * K;
+     int64_t length2 = lengths2[n];
    MinK<scalar_t, int64_t> mink(dists + offset, idxs + offset, K);
-     for (int p2 = 0; p2 < P2; ++p2) {
+     for (int p2 = 0; p2 < length2; ++p2) {
      // Find the distance between cur_point and points2[n, p2]
      scalar_t dist = 0;
      for (int d = 0; d < D; ++d) {
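As the comment above says, V1 differs from V0 only in making D a compile-time template argument so the query point can be cached in the fixed-size thread-local array cur_point. A stripped-down sketch of why that helps (the helper below is hypothetical and not part of this file): with D known at compile time the array can sit in registers and the distance loop can be fully unrolled.

```cuda
// Hypothetical helper illustrating the templated-D pattern used by V1.
template <typename scalar_t, int64_t D>
__device__ scalar_t SquaredDist(
    const scalar_t (&cur_point)[D],        // cached copy of points1[n, p1, :]
    const scalar_t* __restrict__ target) { // pointer to points2[n, p2, 0]
  scalar_t dist = 0;
#pragma unroll
  for (int d = 0; d < D; ++d) {
    const scalar_t diff = cur_point[d] - target[d];
    dist += diff * diff;
  }
  return dist;
}
```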
@@ -89,40 +106,48 @@ struct KNearestNeighborV1Functor {
      size_t threads,
      const scalar_t* __restrict__ points1,
      const scalar_t* __restrict__ points2,
+       const int64_t* __restrict__ lengths1,
+       const int64_t* __restrict__ lengths2,
      scalar_t* __restrict__ dists,
      int64_t* __restrict__ idxs,
      const size_t N,
      const size_t P1,
      const size_t P2,
      const size_t K) {
-     KNearestNeighborKernelV1<scalar_t, D>
-         <<<blocks, threads>>>(points1, points2, dists, idxs, N, P1, P2, K);
+     KNearestNeighborKernelV1<scalar_t, D><<<blocks, threads>>>(
+         points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, K);
  }
};

template <typename scalar_t, int64_t D, int64_t K>
__global__ void KNearestNeighborKernelV2(
    const scalar_t* __restrict__ points1,
    const scalar_t* __restrict__ points2,
+     const int64_t* __restrict__ lengths1,
+     const int64_t* __restrict__ lengths2,
    scalar_t* __restrict__ dists,
    int64_t* __restrict__ idxs,
    const int64_t N,
    const int64_t P1,
    const int64_t P2) {
  // Same general implementation as V1, but also hoist K into a template arg.
-   const int tid = threadIdx.x + blockIdx.x * blockDim.x;
-   const int num_threads = blockDim.x * gridDim.x;
  scalar_t cur_point[D];
  scalar_t min_dists[K];
  int min_idxs[K];
-   for (int np = tid; np < N * P1; np += num_threads) {
-     int n = np / P1;
-     int p1 = np % P1;
+   const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
+   const int64_t chunks_to_do = N * chunks_per_cloud;
+   for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
+     const int64_t n = chunk / chunks_per_cloud;
+     const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
+     int64_t p1 = start_point + threadIdx.x;
+     if (p1 >= lengths1[n])
+       continue;
    for (int d = 0; d < D; ++d) {
      cur_point[d] = points1[n * P1 * D + p1 * D + d];
    }
+     int64_t length2 = lengths2[n];
    MinK<scalar_t, int> mink(min_dists, min_idxs, K);
-     for (int p2 = 0; p2 < P2; ++p2) {
+     for (int p2 = 0; p2 < length2; ++p2) {
      scalar_t dist = 0;
      for (int d = 0; d < D; ++d) {
        int offset = n * P2 * D + p2 * D + d;
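The KNearestNeighbor*Functor structs exist so that a dispatcher can turn the runtime D (and, for V2/V3, K) into concrete template instantiations; the actual machinery lives in dispatch.cuh and is not shown in this diff. Purely as a sketch of that idea, with hypothetical names and supported values:

```cuda
// Sketch only: map a runtime K onto one of a few compiled instantiations.
// The real code uses the helpers in dispatch.cuh; this switch is illustrative.
template <typename scalar_t, int64_t D, int64_t K>
struct LaunchV2 {
  static void run(size_t blocks, size_t threads) {
    // KNearestNeighborKernelV2<scalar_t, D, K><<<blocks, threads>>>(
    //     points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2);
  }
};

template <typename scalar_t, int64_t D>
void DispatchOverK(int64_t k, size_t blocks, size_t threads) {
  switch (k) {
    case 1: LaunchV2<scalar_t, D, 1>::run(blocks, threads); break;
    case 2: LaunchV2<scalar_t, D, 2>::run(blocks, threads); break;
    case 3: LaunchV2<scalar_t, D, 3>::run(blocks, threads); break;
    default: break;  // a runtime-K kernel like V1 would presumably be the fallback
  }
}
```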
@@ -146,20 +171,24 @@ struct KNearestNeighborKernelV2Functor {
      size_t threads,
      const scalar_t* __restrict__ points1,
      const scalar_t* __restrict__ points2,
+       const int64_t* __restrict__ lengths1,
+       const int64_t* __restrict__ lengths2,
      scalar_t* __restrict__ dists,
      int64_t* __restrict__ idxs,
      const int64_t N,
      const int64_t P1,
      const int64_t P2) {
-     KNearestNeighborKernelV2<scalar_t, D, K>
-         <<<blocks, threads>>>(points1, points2, dists, idxs, N, P1, P2);
+     KNearestNeighborKernelV2<scalar_t, D, K><<<blocks, threads>>>(
+         points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2);
  }
};

template <typename scalar_t, int D, int K>
__global__ void KNearestNeighborKernelV3(
    const scalar_t* __restrict__ points1,
    const scalar_t* __restrict__ points2,
+     const int64_t* __restrict__ lengths1,
+     const int64_t* __restrict__ lengths2,
    scalar_t* __restrict__ dists,
    int64_t* __restrict__ idxs,
    const size_t N,
@@ -169,19 +198,23 @@ __global__ void KNearestNeighborKernelV3(
  // Enabling sorting for this version leads to huge slowdowns; I suspect
  // that it forces min_dists into local memory rather than registers.
  // As a result this version is always unsorted.
-   const int tid = threadIdx.x + blockIdx.x * blockDim.x;
-   const int num_threads = blockDim.x * gridDim.x;
  scalar_t cur_point[D];
  scalar_t min_dists[K];
  int min_idxs[K];
-   for (int np = tid; np < N * P1; np += num_threads) {
-     int n = np / P1;
-     int p1 = np % P1;
+   const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
+   const int64_t chunks_to_do = N * chunks_per_cloud;
+   for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
+     const int64_t n = chunk / chunks_per_cloud;
+     const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
+     int64_t p1 = start_point + threadIdx.x;
+     if (p1 >= lengths1[n])
+       continue;
    for (int d = 0; d < D; ++d) {
      cur_point[d] = points1[n * P1 * D + p1 * D + d];
    }
+     int64_t length2 = lengths2[n];
    RegisterMinK<scalar_t, int, K> mink(min_dists, min_idxs);
-     for (int p2 = 0; p2 < P2; ++p2) {
+     for (int p2 = 0; p2 < length2; ++p2) {
      scalar_t dist = 0;
      for (int d = 0; d < D; ++d) {
        int offset = n * P2 * D + p2 * D + d;
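RegisterMinK comes from mink.cuh and is not shown here, but the comment above explains the constraint it is built around: any dynamically computed index into min_dists/min_idxs pushes the arrays out of registers into local memory. A common way around that, shown below purely as an assumed illustration (not the actual RegisterMinK implementation), is to turn a runtime index into an unrolled chain of comparisons so every real array access uses a compile-time index:

```cuda
// Assumed illustration of the static-indexing trick: after unrolling, k is a
// constant in every iteration, so arr[] can stay in registers even though the
// requested index i is only known at runtime.
template <typename T, int K>
__device__ void SetStatic(T (&arr)[K], int i, T val) {
#pragma unroll
  for (int k = 0; k < K; ++k) {
    if (k == i) {
      arr[k] = val;
    }
  }
}
```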
@@ -205,13 +238,15 @@ struct KNearestNeighborKernelV3Functor {
      size_t threads,
      const scalar_t* __restrict__ points1,
      const scalar_t* __restrict__ points2,
+       const int64_t* __restrict__ lengths1,
+       const int64_t* __restrict__ lengths2,
      scalar_t* __restrict__ dists,
      int64_t* __restrict__ idxs,
      const size_t N,
      const size_t P1,
      const size_t P2) {
-     KNearestNeighborKernelV3<scalar_t, D, K>
-         <<<blocks, threads>>>(points1, points2, dists, idxs, N, P1, P2);
+     KNearestNeighborKernelV3<scalar_t, D, K><<<blocks, threads>>>(
+         points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2);
  }
};

@@ -257,6 +292,8 @@ int ChooseVersion(const int64_t D, const int64_t K) {
std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
    const at::Tensor& p1,
    const at::Tensor& p2,
+     const at::Tensor& lengths1,
+     const at::Tensor& lengths2,
    int K,
    int version) {
  const auto N = p1.size(0);
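The host entry point now takes per-cloud lengths alongside the padded point tensors. A hedged call-site sketch (tensor names, K, and version are assumed, not taken from this commit): for a fully packed batch the lengths are constant, while a ragged batch passes the true per-cloud counts.

```cuda
// Assumed call-site sketch. p1: (N, P1, D) float CUDA tensor, p2: (N, P2, D).
// Rows of cloud n at or beyond lengths1[n] / lengths2[n] are padding.
auto long_opts = p1.options().dtype(at::kLong);
at::Tensor lengths1 = at::full({p1.size(0)}, p1.size(1), long_opts);  // fully packed
at::Tensor lengths2 = at::full({p2.size(0)}, p2.size(1), long_opts);
// For ragged batches, fill lengths1/lengths2 with the true per-cloud counts instead.
auto out = KNearestNeighborIdxCuda(p1, p2, lengths1, lengths2, /*K=*/3, /*version=*/-1);
at::Tensor idxs = std::get<0>(out);   // (N, P1, K) indices into p2
at::Tensor dists = std::get<1>(out);  // (N, P1, K) distances to the K neighbors
```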
@@ -267,8 +304,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(

  AT_ASSERTM(p2.size(2) == D, "Point sets must have the same last dimension");
  auto long_dtype = p1.options().dtype(at::kLong);
-   auto idxs = at::full({N, P1, K}, -1, long_dtype);
-   auto dists = at::full({N, P1, K}, -1, p1.options());
+   auto idxs = at::zeros({N, P1, K}, long_dtype);
+   auto dists = at::zeros({N, P1, K}, p1.options());

  if (version < 0) {
    version = ChooseVersion(D, K);
@@ -294,6 +331,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
        <<<blocks, threads>>>(
            p1.data_ptr<scalar_t>(),
            p2.data_ptr<scalar_t>(),
+             lengths1.data_ptr<int64_t>(),
+             lengths2.data_ptr<int64_t>(),
            dists.data_ptr<scalar_t>(),
            idxs.data_ptr<int64_t>(),
            N,
@@ -314,6 +353,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
        threads,
        p1.data_ptr<scalar_t>(),
        p2.data_ptr<scalar_t>(),
+         lengths1.data_ptr<int64_t>(),
+         lengths2.data_ptr<int64_t>(),
        dists.data_ptr<scalar_t>(),
        idxs.data_ptr<int64_t>(),
        N,
@@ -336,6 +377,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
        threads,
        p1.data_ptr<scalar_t>(),
        p2.data_ptr<scalar_t>(),
+         lengths1.data_ptr<int64_t>(),
+         lengths2.data_ptr<int64_t>(),
        dists.data_ptr<scalar_t>(),
        idxs.data_ptr<int64_t>(),
        N,
@@ -357,6 +400,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
        threads,
        p1.data_ptr<scalar_t>(),
        p2.data_ptr<scalar_t>(),
+         lengths1.data_ptr<int64_t>(),
+         lengths2.data_ptr<int64_t>(),
        dists.data_ptr<scalar_t>(),
        idxs.data_ptr<int64_t>(),
        N,
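One consequence of switching the outputs to at::zeros together with the `if (p1 >= lengths1[n]) continue;` guard in every kernel: output rows that correspond to padding points are never written and simply stay zero. A caller that wants to ignore those rows can derive a mask from lengths1; a small hedged sketch (hypothetical helper, not part of this commit):

```cuda
// Hypothetical helper: (N, P1) boolean mask that is true exactly for the rows
// the kernels actually filled, i.e. p1 < lengths1[n].
at::Tensor ValidRows(const at::Tensor& lengths1, int64_t P1) {
  at::Tensor arange = at::arange(P1, lengths1.options());  // (P1,) int64
  return arange.unsqueeze(0) < lengths1.unsqueeze(1);      // (N, P1) bool
}
```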