
Commit 870290d

jcjohnson authored and facebook-github-bot committed
Implement K-Nearest Neighbors
Summary: Implements K-Nearest Neighbors with C++ and CUDA versions. KNN in CUDA is highly nontrivial. I've implemented a few different versions of the kernel, and we heuristically dispatch to different kernels based on the problem size. Some of the kernels rely on template specialization on either D or K, so we use template metaprogramming to compile specialized versions for ranges of D and K.

These kernels are up to 3x faster than our existing 1-nearest-neighbor kernels, so we should also consider swapping out `nn_points_idx` to use these kernels in the backend.

I've been working mostly on the CUDA kernels, and haven't converged on the correct Python API. I still want to benchmark against FAISS to see how far away we are from their performance.

Reviewed By: bottler

Differential Revision: D19729286

fbshipit-source-id: 608ffbb7030c21fe4008f330522f4890f0c3c21a
1 parent 02d4968 commit 870290d
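As a rough illustration of the "heuristically dispatch to different kernels based on the problem size" idea described in the summary, here is a minimal host-side sketch. The function names, size thresholds, and signatures below are invented for illustration and do not appear anywhere in this commit.

#include <cstdint>
#include <iostream>

// Hypothetical stand-ins for two of the kernel variants described in the
// summary; the real launchers, names, and thresholds are not part of this diff.
void LaunchKnnRegisterVariant(int64_t D, int64_t K) {
  std::cout << "register-specialized variant, D=" << D << " K=" << K << "\n";
}
void LaunchKnnGenericVariant(int64_t D, int64_t K) {
  std::cout << "generic fallback variant, D=" << D << " K=" << K << "\n";
}

// Heuristic dispatch: prefer the specialized variant only when D and K are
// small enough for per-thread arrays to plausibly stay in registers.
void LaunchKnn(int64_t D, int64_t K) {
  if (D <= 4 && K <= 32) {
    LaunchKnnRegisterVariant(D, K);
  } else {
    LaunchKnnGenericVariant(D, K);
  }
}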

File tree

12 files changed: +1328 -1 lines changed


pytorch3d/csrc/dispatch.cuh

+261
@@ -0,0 +1,261 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
//
// This file provides utilities for dispatching to specialized versions of functions.
// This is especially useful for CUDA kernels, since specializing them to particular
// input sizes can often allow the compiler to unroll loops and place arrays into
// registers, which can give huge performance speedups.
//
// As an example, suppose we have the following function which is specialized
// based on a compile-time int64_t value:
//
// template<typename T, int64_t x>
// struct SquareOffset {
//   static void run(T y) {
//     T val = x * x + y;
//     std::cout << val << std::endl;
//   }
// }
//
// This function takes one compile-time argument x, and one run-time argument y.
// We might want to compile specialized versions of this for x=0, x=1, etc and
// then dispatch to the correct one based on the runtime value of x.
// One simple way to achieve this is with a lookup table:
//
// template<typename T>
// void DispatchSquareOffset(const int64_t x, T y) {
//   if (x == 0) {
//     SquareOffset<T, 0>::run(y);
//   } else if (x == 1) {
//     SquareOffset<T, 1>::run(y);
//   } else if (x == 2) {
//     SquareOffset<T, 2>::run(y);
//   }
// }
//
// This function takes both x and y as run-time arguments, and dispatches to
// different specialized versions of SquareOffset based on the run-time value
// of x. This works, but it's tedious and error-prone. If we want to change the
// set of x values for which we provide compile-time specializations, then we
// will need to do a lot of tedious editing of the dispatch function. Also, if we
// want to provide compile-time specializations for another function other than
// SquareOffset, we will need to duplicate the entire lookup table.
//
// To solve these problems, we can use the DispatchKernel1D function provided by
// this file instead:
//
// template<typename T>
// void DispatchSquareOffset(const int64_t x, T y) {
//   constexpr int64_t xmin = 0;
//   constexpr int64_t xmax = 2;
//   DispatchKernel1D<SquareOffset, T, xmin, xmax>(x, y);
// }
//
// DispatchKernel1D uses template metaprogramming to compile specialized
// versions of SquareOffset for all values of x with xmin <= x <= xmax, and
// then dispatches to the correct one based on the run-time value of x. If we
// want to change the range of x values for which SquareOffset is specialized
// at compile-time, then all we have to do is change the values of the
// compile-time constants xmin and xmax.
//
// This file also allows us to similarly dispatch functions that depend on two
// compile-time int64_t values, using the DispatchKernel2D function like this:
//
// template<typename T, int64_t x, int64_t y>
// struct Sum {
//   static void run(T z, T w) {
//     T val = x + y + z + w;
//     std::cout << val << std::endl;
//   }
// }
//
// template<typename T>
// void DispatchSum(const int64_t x, const int64_t y, int z, int w) {
//   constexpr int64_t xmin = 1;
//   constexpr int64_t xmax = 3;
//   constexpr int64_t ymin = 2;
//   constexpr int64_t ymax = 5;
//   DispatchKernel2D<Sum, T, xmin, xmax, ymin, ymax>(x, y, z, w);
// }
//
// Like its 1D counterpart, DispatchKernel2D uses template metaprogramming to
// compile specialized versions of Sum for all values of (x, y) with
// xmin <= x <= xmax and ymin <= y <= ymax, then dispatches to the correct
// specialized version based on the runtime values of x and y.

// Define some helper structs in an anonymous namespace.
namespace {

// 1D dispatch: general case.
// Kernel is the function we want to dispatch to; it should take a typename and
// an int64_t as template args, and it should define a static void function
// run which takes any number of arguments of any type.
// In order to dispatch, we will take an additional template argument curN,
// and increment it via template recursion until it is equal to the run-time
// argument N.
template<
  template<typename, int64_t> class Kernel,
  typename T,
  int64_t minN,
  int64_t maxN,
  int64_t curN,
  typename... Args
>
struct DispatchKernelHelper1D {
  static void run(const int64_t N, Args... args) {
    if (curN == N) {
      // The compile-time value curN is equal to the run-time value N, so we
      // can dispatch to the run method of the Kernel.
      Kernel<T, curN>::run(args...);
    } else if (curN < N) {
      // Increment curN via template recursion
      DispatchKernelHelper1D<Kernel, T, minN, maxN, curN + 1, Args...>::run(N, args...);
    }
    // We shouldn't get here -- throw an error?
  }
};


// 1D dispatch: Specialization when curN == maxN
// We need this base case to avoid infinite template recursion.
template<
  template<typename, int64_t> class Kernel,
  typename T,
  int64_t minN,
  int64_t maxN,
  typename... Args
>
struct DispatchKernelHelper1D<Kernel, T, minN, maxN, maxN, Args...> {
  static void run(const int64_t N, Args... args) {
    if (N == maxN) {
      Kernel<T, maxN>::run(args...);
    }
    // We shouldn't get here -- throw an error?
  }
};


// 2D dispatch, general case.
// This is similar to the 1D case: we take additional template args curN and
// curM, and increment them via template recursion until they are equal to
// the run-time values of N and M, at which point we dispatch to the run
// method of the kernel.
template<
  template<typename, int64_t, int64_t> class Kernel,
  typename T,
  int64_t minN, int64_t maxN, int64_t curN,
  int64_t minM, int64_t maxM, int64_t curM,
  typename... Args
>
struct DispatchKernelHelper2D {
  static void run(const int64_t N, const int64_t M, Args... args) {
    if (curN == N && curM == M) {
      Kernel<T, curN, curM>::run(args...);
    } else if (curN < N && curM < M) {
      // Increment both curN and curM. This isn't strictly necessary; we could
      // just increment one or the other at each step. But this helps to cut
      // down on the number of recursive calls we make.
      DispatchKernelHelper2D<Kernel, T, minN, maxN, curN + 1, minM, maxM, curM + 1, Args...>::run(N, M, args...);
    } else if (curN < N) {
      // Increment curN only
      DispatchKernelHelper2D<Kernel, T, minN, maxN, curN + 1, minM, maxM, curM, Args...>::run(N, M, args...);
    } else if (curM < M) {
      // Increment curM only
      DispatchKernelHelper2D<Kernel, T, minN, maxN, curN, minM, maxM, curM + 1, Args...>::run(N, M, args...);
    }
  }
};


// 2D dispatch, specialization for curN == maxN
template<
  template<typename, int64_t, int64_t> class Kernel,
  typename T,
  int64_t minN, int64_t maxN,
  int64_t minM, int64_t maxM, int64_t curM,
  typename... Args
>
struct DispatchKernelHelper2D<Kernel, T, minN, maxN, maxN, minM, maxM, curM, Args...> {
  static void run(const int64_t N, const int64_t M, Args... args) {
    if (maxN == N && curM == M) {
      Kernel<T, maxN, curM>::run(args...);
    } else if (curM < maxM) {
      DispatchKernelHelper2D<Kernel, T, minN, maxN, maxN, minM, maxM, curM + 1, Args...>::run(N, M, args...);
    }
    // We should not get here -- throw an error?
  }
};


// 2D dispatch, specialization for curM == maxM
template<
  template<typename, int64_t, int64_t> class Kernel,
  typename T,
  int64_t minN, int64_t maxN, int64_t curN,
  int64_t minM, int64_t maxM,
  typename... Args
>
struct DispatchKernelHelper2D<Kernel, T, minN, maxN, curN, minM, maxM, maxM, Args...> {
  static void run(const int64_t N, const int64_t M, Args... args) {
    if (curN == N && maxM == M) {
      Kernel<T, curN, maxM>::run(args...);
    } else if (curN < maxN) {
      DispatchKernelHelper2D<Kernel, T, minN, maxN, curN + 1, minM, maxM, maxM, Args...>::run(N, M, args...);
    }
    // We should not get here -- throw an error?
  }
};


// 2D dispatch, specialization for curN == maxN, curM == maxM
template<
  template<typename, int64_t, int64_t> class Kernel,
  typename T,
  int64_t minN, int64_t maxN,
  int64_t minM, int64_t maxM,
  typename... Args
>
struct DispatchKernelHelper2D<Kernel, T, minN, maxN, maxN, minM, maxM, maxM, Args...> {
  static void run(const int64_t N, const int64_t M, Args... args) {
    if (maxN == N && maxM == M) {
      Kernel<T, maxN, maxM>::run(args...);
    }
    // We should not get here -- throw an error?
  }
};

} // namespace


// This is the function we expect users to call to dispatch to 1D functions
template<
  template<typename, int64_t> class Kernel,
  typename T,
  int64_t minN,
  int64_t maxN,
  typename... Args
>
void DispatchKernel1D(const int64_t N, Args... args) {
  if (minN <= N && N <= maxN) {
    // Kick off the template recursion by calling the Helper with curN = minN
    DispatchKernelHelper1D<Kernel, T, minN, maxN, minN, Args...>::run(N, args...);
  }
  // Maybe throw an error if we tried to dispatch outside the allowed range?
}


// This is the function we expect users to call to dispatch to 2D functions
template<
  template<typename, int64_t, int64_t> class Kernel,
  typename T,
  int64_t minN, int64_t maxN,
  int64_t minM, int64_t maxM,
  typename... Args
>
void DispatchKernel2D(const int64_t N, const int64_t M, Args... args) {
  if (minN <= N && N <= maxN && minM <= M && M <= maxM) {
    // Kick off the template recursion by calling the Helper with curN = minN
    // and curM = minM
    DispatchKernelHelper2D<Kernel, T, minN, maxN, minN, minM, maxM, minM, Args...>::run(N, M, args...);
  }
  // Maybe throw an error if we tried to dispatch outside the specified range?
}
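For a concrete sense of how a kernel specialized on two sizes (such as D and K) might plug into this utility, here is a minimal sketch. It is not part of this commit: ToyKnnKernel, RunToyKnn, and the chosen bounds are invented for illustration, and it assumes dispatch.cuh can be included from ordinary C++ (it contains only host-side template code).

#include <cstdint>
#include <iostream>
#include "dispatch.cuh"

// Toy functor specialized on compile-time D and K. A real launcher would
// start a CUDA kernel templated on <scalar_t, D, K> here instead of printing.
template <typename T, int64_t D, int64_t K>
struct ToyKnnKernel {
  static void run(int64_t num_points) {
    std::cout << "specialized for D=" << D << ", K=" << K
              << ", processing " << num_points << " points" << std::endl;
  }
};

// Host-side entry point: D and K are only known at run time, but the compiler
// emits one ToyKnnKernel instantiation per (D, K) pair in the chosen range.
template <typename T>
void RunToyKnn(const int64_t D, const int64_t K, const int64_t num_points) {
  constexpr int64_t minD = 1;
  constexpr int64_t maxD = 4;
  constexpr int64_t minK = 1;
  constexpr int64_t maxK = 8;
  // Out-of-range sizes are silently skipped by DispatchKernel2D, so a real
  // caller would check the range first and fall back to a generic kernel.
  DispatchKernel2D<ToyKnnKernel, T, minD, maxD, minK, maxK>(D, K, num_points);
}

int main() {
  RunToyKnn<float>(3, 5, 10000);  // prints the D=3, K=5 specialization
  return 0;
}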

pytorch3d/csrc/ext.cpp

+2
@@ -6,6 +6,7 @@
 #include "compositing/weighted_sum.h"
 #include "face_areas_normals/face_areas_normals.h"
 #include "gather_scatter/gather_scatter.h"
+#include "knn/knn.h"
 #include "nearest_neighbor_points/nearest_neighbor_points.h"
 #include "packed_to_padded_tensor/packed_to_padded_tensor.h"
 #include "rasterize_meshes/rasterize_meshes.h"
@@ -16,6 +17,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("face_areas_normals_backward", &FaceAreasNormalsBackward);
   m.def("packed_to_padded", &PackedToPadded);
   m.def("padded_to_packed", &PaddedToPacked);
+  m.def("knn_points_idx", &KNearestNeighborIdx);
   m.def("nn_points_idx", &NearestNeighborIdx);
   m.def("gather_scatter", &gather_scatter);
   m.def("rasterize_points", &RasterizePoints);

pytorch3d/csrc/index_utils.cuh

+120
@@ -0,0 +1,120 @@
1+
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2+
3+
// This converts dynamic array lookups into static array lookups, for small
4+
// arrays up to size 32.
5+
//
6+
// Suppose we have a small thread-local array:
7+
//
8+
// float vals[10];
9+
//
10+
// Ideally we should only index this array using static indices:
11+
//
12+
// for (int i = 0; i < 10; ++i) vals[i] = i * i;
13+
//
14+
// If we do so, then the CUDA compiler may be able to place the array into
15+
// registers, which can have a big performance improvement. However if we
16+
// access the array dynamically, the the compiler may force the array into
17+
// local memory, which has the same latency as global memory.
18+
//
19+
// These functions convert dynamic array access into static array access
20+
// using a brute-force lookup table. It can be used like this:
21+
//
22+
// float vals[10];
23+
// int idx = 3;
24+
// float val = 3.14f;
25+
// RegisterIndexUtils<float, 10>::set(vals, idx, val);
26+
// float val2 = RegisterIndexUtils<float, 10>::get(vals, idx);
27+
//
28+
// The implementation is based on fbcuda/RegisterUtils.cuh:
29+
// https://github.com/facebook/fbcuda/blob/master/RegisterUtils.cuh
30+
// To avoid depending on the entire library, we just reimplement these two
31+
// functions. The fbcuda implementation is a bit more sophisticated, and uses
32+
// the preprocessor to generate switch statements that go up to N for each
33+
// value of N. We are lazy and just have a giant explicit switch statement.
34+
//
35+
// We might be able to use a template metaprogramming approach similar to
36+
// DispatchKernel1D for this. However DispatchKernel1D is intended to be used
37+
// for dispatching to the correct CUDA kernel on the host, while this is
38+
// is intended to run on the device. I was concerned that a metaprogramming
39+
// approach for this might lead to extra function calls at runtime if the
40+
// compiler fails to optimize them away, which could be very slow on device.
41+
// However I didn't actually benchmark or test this.
42+
template<typename T, int N>
43+
struct RegisterIndexUtils {
44+
__device__ __forceinline__ static T get(const T arr[N], int idx) {
45+
if (idx < 0 || idx >= N) return T();
46+
switch (idx) {
47+
case 0: return arr[0];
48+
case 1: return arr[1];
49+
case 2: return arr[2];
50+
case 3: return arr[3];
51+
case 4: return arr[4];
52+
case 5: return arr[5];
53+
case 6: return arr[6];
54+
case 7: return arr[7];
55+
case 8: return arr[8];
56+
case 9: return arr[9];
57+
case 10: return arr[10];
58+
case 11: return arr[11];
59+
case 12: return arr[12];
60+
case 13: return arr[13];
61+
case 14: return arr[14];
62+
case 15: return arr[15];
63+
case 16: return arr[16];
64+
case 17: return arr[17];
65+
case 18: return arr[18];
66+
case 19: return arr[19];
67+
case 20: return arr[20];
68+
case 21: return arr[21];
69+
case 22: return arr[22];
70+
case 23: return arr[23];
71+
case 24: return arr[24];
72+
case 25: return arr[25];
73+
case 26: return arr[26];
74+
case 27: return arr[27];
75+
case 28: return arr[28];
76+
case 29: return arr[29];
77+
case 30: return arr[30];
78+
case 31: return arr[31];
79+
};
80+
return T();
81+
}
82+
83+
__device__ __forceinline__ static void set(T arr[N], int idx, T val) {
84+
if (idx < 0 || idx >= N) return;
85+
switch (idx) {
86+
case 0: arr[0] = val; break;
87+
case 1: arr[1] = val; break;
88+
case 2: arr[2] = val; break;
89+
case 3: arr[3] = val; break;
90+
case 4: arr[4] = val; break;
91+
case 5: arr[5] = val; break;
92+
case 6: arr[6] = val; break;
93+
case 7: arr[7] = val; break;
94+
case 8: arr[8] = val; break;
95+
case 9: arr[9] = val; break;
96+
case 10: arr[10] = val; break;
97+
case 11: arr[11] = val; break;
98+
case 12: arr[12] = val; break;
99+
case 13: arr[13] = val; break;
100+
case 14: arr[14] = val; break;
101+
case 15: arr[15] = val; break;
102+
case 16: arr[16] = val; break;
103+
case 17: arr[17] = val; break;
104+
case 18: arr[18] = val; break;
105+
case 19: arr[19] = val; break;
106+
case 20: arr[20] = val; break;
107+
case 21: arr[21] = val; break;
108+
case 22: arr[22] = val; break;
109+
case 23: arr[23] = val; break;
110+
case 24: arr[24] = val; break;
111+
case 25: arr[25] = val; break;
112+
case 26: arr[26] = val; break;
113+
case 27: arr[27] = val; break;
114+
case 28: arr[28] = val; break;
115+
case 29: arr[29] = val; break;
116+
case 30: arr[30] = val; break;
117+
case 31: arr[31] = val; break;
118+
}
119+
}
120+
};
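To illustrate why this matters for a KNN kernel, here is a hypothetical device-side helper that maintains a per-thread list of the K best (smallest-distance) neighbors seen so far. It is not part of this commit; the name InsertCandidate and its signature are invented, and it assumes the dists array was initialized to a large sentinel value. Because the insertion position is a run-time value, get/set are used instead of raw dynamic indexing, so dists[] and idxs[] can stay in registers.

// Hypothetical helper (not in this commit): insert a candidate distance/index
// pair into per-thread arrays of the K best neighbors, kept sorted by distance.
template <typename T, int K>
__device__ void InsertCandidate(T dists[K], int64_t idxs[K], T dist, int64_t idx) {
  // Worse than the current K-th best; nothing to do.
  if (dist >= RegisterIndexUtils<T, K>::get(dists, K - 1)) {
    return;
  }
  // Shift worse entries down by one slot, then write the new candidate.
  int pos = K - 1;
  while (pos > 0 && dist < RegisterIndexUtils<T, K>::get(dists, pos - 1)) {
    RegisterIndexUtils<T, K>::set(dists, pos, RegisterIndexUtils<T, K>::get(dists, pos - 1));
    RegisterIndexUtils<int64_t, K>::set(idxs, pos, RegisterIndexUtils<int64_t, K>::get(idxs, pos - 1));
    --pos;
  }
  RegisterIndexUtils<T, K>::set(dists, pos, dist);
  RegisterIndexUtils<int64_t, K>::set(idxs, pos, idx);
}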
