Skip to content

Commit b2b0c5a

Browse files
gkioxarifacebook-github-bot
authored andcommitted
knn autograd
Summary: Adds knn backward to return `grad_pts1` and `grad_pts2`. Adds `knn_gather` to return the nearest neighbors in pts2. The BM tests include backward pass and are ran on an M40. ``` Benchmark Avg Time(μs) Peak Time(μs) Iterations -------------------------------------------------------------------------------- KNN_SQUARE_32_256_128_3_24_cpu 39558 43485 13 KNN_SQUARE_32_256_128_3_24_cuda:0 1080 1404 463 KNN_SQUARE_32_256_512_3_24_cpu 81950 85781 7 KNN_SQUARE_32_256_512_3_24_cuda:0 1519 1641 330 -------------------------------------------------------------------------------- Benchmark Avg Time(μs) Peak Time(μs) Iterations -------------------------------------------------------------------------------- KNN_RAGGED_32_256_128_3_24_cpu 13798 14650 37 KNN_RAGGED_32_256_128_3_24_cuda:0 1576 1713 318 KNN_RAGGED_32_256_512_3_24_cpu 31255 32210 16 KNN_RAGGED_32_256_512_3_24_cuda:0 2024 2162 248 -------------------------------------------------------------------------------- ``` Reviewed By: jcjohnson Differential Revision: D20945556 fbshipit-source-id: a16f616029c6b5f8c2afceb5f2bc12c5c20d2f3c
1 parent 487d4d6 commit b2b0c5a

File tree

8 files changed

+555
-375
lines changed

8 files changed

+555
-375
lines changed

pytorch3d/csrc/ext.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
2020
m.def("packed_to_padded", &PackedToPadded);
2121
m.def("padded_to_packed", &PaddedToPacked);
2222
m.def("knn_points_idx", &KNearestNeighborIdx);
23+
m.def("knn_points_backward", &KNearestNeighborBackward);
2324
m.def("nn_points_idx", &NearestNeighborIdx);
2425
m.def("gather_scatter", &gather_scatter);
2526
m.def("rasterize_points", &RasterizePoints);

pytorch3d/csrc/knn/knn.cu

+90
Original file line numberDiff line numberDiff line change
@@ -412,3 +412,93 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
412412

413413
return std::make_tuple(idxs, dists);
414414
}
415+
416+
// ------------------------------------------------------------- //
417+
// Backward Operators //
418+
// ------------------------------------------------------------- //
419+
420+
// TODO(gkioxari) support all data types once AtomicAdd supports doubles.
421+
// Currently, support is for floats only.
422+
__global__ void KNearestNeighborBackwardKernel(
423+
const float* __restrict__ p1, // (N, P1, D)
424+
const float* __restrict__ p2, // (N, P2, D)
425+
const int64_t* __restrict__ lengths1, // (N,)
426+
const int64_t* __restrict__ lengths2, // (N,)
427+
const int64_t* __restrict__ idxs, // (N, P1, K)
428+
const float* __restrict__ grad_dists, // (N, P1, K)
429+
float* __restrict__ grad_p1, // (N, P1, D)
430+
float* __restrict__ grad_p2, // (N, P2, D)
431+
const size_t N,
432+
const size_t P1,
433+
const size_t P2,
434+
const size_t K,
435+
const size_t D) {
436+
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
437+
const size_t stride = gridDim.x * blockDim.x;
438+
439+
for (size_t i = tid; i < N * P1 * K * D; i += stride) {
440+
const size_t n = i / (P1 * K * D); // batch index
441+
size_t rem = i % (P1 * K * D);
442+
const size_t p1_idx = rem / (K * D); // index of point in p1
443+
rem = rem % (K * D);
444+
const size_t k = rem / D; // k-th nearest neighbor
445+
const size_t d = rem % D; // d-th dimension in the feature vector
446+
447+
const size_t num1 = lengths1[n]; // number of valid points in p1 in batch
448+
const size_t num2 = lengths2[n]; // number of valid points in p2 in batch
449+
if ((p1_idx < num1) && (k < num2)) {
450+
const float grad_dist = grad_dists[n * P1 * K + p1_idx * K + k];
451+
// index of point in p2 corresponding to the k-th nearest neighbor
452+
const size_t p2_idx = idxs[n * P1 * K + p1_idx * K + k];
453+
const float diff = 2.0 * grad_dist *
454+
(p1[n * P1 * D + p1_idx * D + d] - p2[n * P2 * D + p2_idx * D + d]);
455+
atomicAdd(grad_p1 + n * P1 * D + p1_idx * D + d, diff);
456+
atomicAdd(grad_p2 + n * P2 * D + p2_idx * D + d, -1.0f * diff);
457+
}
458+
}
459+
}
460+
461+
std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCuda(
462+
const at::Tensor& p1,
463+
const at::Tensor& p2,
464+
const at::Tensor& lengths1,
465+
const at::Tensor& lengths2,
466+
const at::Tensor& idxs,
467+
const at::Tensor& grad_dists) {
468+
const auto N = p1.size(0);
469+
const auto P1 = p1.size(1);
470+
const auto P2 = p2.size(1);
471+
const auto D = p2.size(2);
472+
const auto K = idxs.size(2);
473+
474+
AT_ASSERTM(p2.size(2) == D, "Point sets must have the same last dimension");
475+
AT_ASSERTM(idxs.size(0) == N, "KNN idxs must have the same batch dimension");
476+
AT_ASSERTM(
477+
idxs.size(1) == P1, "KNN idxs must have the same point dimension as p1");
478+
AT_ASSERTM(grad_dists.size(0) == N);
479+
AT_ASSERTM(grad_dists.size(1) == P1);
480+
AT_ASSERTM(grad_dists.size(2) == K);
481+
482+
auto grad_p1 = at::zeros({N, P1, D}, p1.options());
483+
auto grad_p2 = at::zeros({N, P2, D}, p2.options());
484+
485+
const int blocks = 64;
486+
const int threads = 512;
487+
488+
KNearestNeighborBackwardKernel<<<blocks, threads>>>(
489+
p1.data_ptr<float>(),
490+
p2.data_ptr<float>(),
491+
lengths1.data_ptr<int64_t>(),
492+
lengths2.data_ptr<int64_t>(),
493+
idxs.data_ptr<int64_t>(),
494+
grad_dists.data_ptr<float>(),
495+
grad_p1.data_ptr<float>(),
496+
grad_p2.data_ptr<float>(),
497+
N,
498+
P1,
499+
P2,
500+
K,
501+
D);
502+
503+
return std::make_tuple(grad_p1, grad_p2);
504+
}

pytorch3d/csrc/knn/knn.h

+63-2
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@
1616
// lengths1: LongTensor, shape (N,), giving actual length of each P1 cloud.
1717
// lengths2: LongTensor, shape (N,), giving actual length of each P2 cloud.
1818
// K: int giving the number of nearest points to return.
19-
// sorted: bool telling whether to sort the K returned points by their
20-
// distance.
2119
// version: Integer telling which implementation to use.
2220
//
2321
// Returns:
@@ -67,3 +65,66 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdx(
6765
}
6866
return KNearestNeighborIdxCpu(p1, p2, lengths1, lengths2, K);
6967
}
68+
69+
// Compute gradients with respect to p1 and p2
70+
//
71+
// Args:
72+
// p1: FloatTensor of shape (N, P1, D) giving a batch of pointclouds each
73+
// containing P1 points of dimension D.
74+
// p2: FloatTensor of shape (N, P2, D) giving a batch of pointclouds each
75+
// containing P2 points of dimension D.
76+
// lengths1: LongTensor, shape (N,), giving actual length of each P1 cloud.
77+
// lengths2: LongTensor, shape (N,), giving actual length of each P2 cloud.
78+
// p1_neighbor_idx: LongTensor of shape (N, P1, K), where
79+
// p1_neighbor_idx[n, i, k] = j means that the kth nearest
80+
// neighbor to p1[n, i] in the cloud p2[n] is p2[n, j].
81+
// It is padded with zeros so that it can be used easily in a later
82+
// gather() operation. This is computed from the forward pass.
83+
// grad_dists: FLoatTensor of shape (N, P1, K) which contains the input
84+
// gradients.
85+
//
86+
// Returns:
87+
// grad_p1: FloatTensor of shape (N, P1, D) containing the output gradients
88+
// wrt p1.
89+
// grad_p2: FloatTensor of shape (N, P2, D) containing the output gradients
90+
// wrt p2.
91+
92+
// CPU implementation.
93+
std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCpu(
94+
const at::Tensor& p1,
95+
const at::Tensor& p2,
96+
const at::Tensor& lengths1,
97+
const at::Tensor& lengths2,
98+
const at::Tensor& idxs,
99+
const at::Tensor& grad_dists);
100+
101+
// CUDA implementation
102+
std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCuda(
103+
const at::Tensor& p1,
104+
const at::Tensor& p2,
105+
const at::Tensor& lengths1,
106+
const at::Tensor& lengths2,
107+
const at::Tensor& idxs,
108+
const at::Tensor& grad_dists);
109+
110+
// Implementation which is exposed.
111+
std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackward(
112+
const at::Tensor& p1,
113+
const at::Tensor& p2,
114+
const at::Tensor& lengths1,
115+
const at::Tensor& lengths2,
116+
const at::Tensor& idxs,
117+
const at::Tensor& grad_dists) {
118+
if (p1.is_cuda() || p2.is_cuda()) {
119+
#ifdef WITH_CUDA
120+
CHECK_CONTIGUOUS_CUDA(p1);
121+
CHECK_CONTIGUOUS_CUDA(p2);
122+
return KNearestNeighborBackwardCuda(
123+
p1, p2, lengths1, lengths2, idxs, grad_dists);
124+
#else
125+
AT_ERROR("Not compiled with GPU support.");
126+
#endif
127+
}
128+
return KNearestNeighborBackwardCpu(
129+
p1, p2, lengths1, lengths2, idxs, grad_dists);
130+
}

pytorch3d/csrc/knn/knn_cpu.cpp

+48
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,51 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCpu(
5757
}
5858
return std::make_tuple(idxs, dists);
5959
}
60+
61+
// ------------------------------------------------------------- //
62+
// Backward Operators //
63+
// ------------------------------------------------------------- //
64+
65+
std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCpu(
66+
const at::Tensor& p1,
67+
const at::Tensor& p2,
68+
const at::Tensor& lengths1,
69+
const at::Tensor& lengths2,
70+
const at::Tensor& idxs,
71+
const at::Tensor& grad_dists) {
72+
const int N = p1.size(0);
73+
const int P1 = p1.size(1);
74+
const int D = p1.size(2);
75+
const int P2 = p2.size(1);
76+
const int K = idxs.size(2);
77+
78+
torch::Tensor grad_p1 = torch::full({N, P1, D}, 0, p1.options());
79+
torch::Tensor grad_p2 = torch::full({N, P2, D}, 0, p2.options());
80+
81+
auto p1_a = p1.accessor<float, 3>();
82+
auto p2_a = p2.accessor<float, 3>();
83+
auto lengths1_a = lengths1.accessor<int64_t, 1>();
84+
auto lengths2_a = lengths2.accessor<int64_t, 1>();
85+
auto idxs_a = idxs.accessor<int64_t, 3>();
86+
auto grad_dists_a = grad_dists.accessor<float, 3>();
87+
auto grad_p1_a = grad_p1.accessor<float, 3>();
88+
auto grad_p2_a = grad_p2.accessor<float, 3>();
89+
90+
for (int n = 0; n < N; ++n) {
91+
const int64_t length1 = lengths1_a[n];
92+
int64_t length2 = lengths2_a[n];
93+
length2 = (length2 < K) ? length2 : K;
94+
for (int64_t i1 = 0; i1 < length1; ++i1) {
95+
for (int64_t k = 0; k < length2; ++k) {
96+
const int64_t i2 = idxs_a[n][i1][k];
97+
for (int64_t d = 0; d < D; ++d) {
98+
const float diff =
99+
2.0f * grad_dists_a[n][i1][k] * (p1_a[n][i1][d] - p2_a[n][i2][d]);
100+
grad_p1_a[n][i1][d] += diff;
101+
grad_p2_a[n][i2][d] += -1.0f * diff;
102+
}
103+
}
104+
}
105+
}
106+
return std::make_tuple(grad_p1, grad_p2);
107+
}

pytorch3d/ops/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from .cubify import cubify
55
from .graph_conv import GraphConv
6+
from .knn import knn_gather, knn_points
67
from .mesh_face_areas_normals import mesh_face_areas_normals
78
from .nearest_neighbor_points import nn_points_idx
89
from .packed_to_padded import packed_to_padded, padded_to_packed

0 commit comments

Comments
 (0)