@@ -412,3 +412,93 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
  return std::make_tuple(idxs, dists);
}
+
+ // ------------------------------------------------------------- //
+ //                     Backward Operators                         //
+ // ------------------------------------------------------------- //
+
+ // TODO(gkioxari) support all data types once AtomicAdd supports doubles.
+ // Currently, support is for floats only.
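+ //
+ // Each thread handles one (n, p1_idx, k, d) element of the flattened
+ // (N, P1, K, D) problem via a grid-stride loop and accumulates the gradient
+ // of the squared Euclidean KNN distances into grad_p1 and grad_p2. atomicAdd
+ // is needed because different k write to the same grad_p1 entry and
+ // different (p1_idx, k) pairs can select the same p2 point.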
+ __global__ void KNearestNeighborBackwardKernel(
+     const float* __restrict__ p1, // (N, P1, D)
+     const float* __restrict__ p2, // (N, P2, D)
+     const int64_t* __restrict__ lengths1, // (N,)
+     const int64_t* __restrict__ lengths2, // (N,)
+     const int64_t* __restrict__ idxs, // (N, P1, K)
+     const float* __restrict__ grad_dists, // (N, P1, K)
+     float* __restrict__ grad_p1, // (N, P1, D)
+     float* __restrict__ grad_p2, // (N, P2, D)
+     const size_t N,
+     const size_t P1,
+     const size_t P2,
+     const size_t K,
+     const size_t D) {
+   const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
+   const size_t stride = gridDim.x * blockDim.x;
+
+   for (size_t i = tid; i < N * P1 * K * D; i += stride) {
+     const size_t n = i / (P1 * K * D); // batch index
+     size_t rem = i % (P1 * K * D);
+     const size_t p1_idx = rem / (K * D); // index of point in p1
+     rem = rem % (K * D);
+     const size_t k = rem / D; // k-th nearest neighbor
+     const size_t d = rem % D; // d-th dimension in the feature vector
+
+     const size_t num1 = lengths1[n]; // number of valid points in p1 in batch
+     const size_t num2 = lengths2[n]; // number of valid points in p2 in batch
+     if ((p1_idx < num1) && (k < num2)) {
+       const float grad_dist = grad_dists[n * P1 * K + p1_idx * K + k];
+       // index of point in p2 corresponding to the k-th nearest neighbor
+       const size_t p2_idx = idxs[n * P1 * K + p1_idx * K + k];
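+       // d/d(p1) of the squared distance ||p1 - p2||^2 along dimension d,
+       // scaled by the upstream gradient for this (point, neighbor) pair.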
+       const float diff = 2.0 * grad_dist *
+           (p1[n * P1 * D + p1_idx * D + d] - p2[n * P2 * D + p2_idx * D + d]);
+       atomicAdd(grad_p1 + n * P1 * D + p1_idx * D + d, diff);
+       atomicAdd(grad_p2 + n * P2 * D + p2_idx * D + d, -1.0f * diff);
+     }
+   }
+ }
+
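+ // Host entry point for the backward pass: validates input shapes, allocates
+ // zero-initialized gradient tensors, and launches the kernel above.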
+ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCuda(
+     const at::Tensor& p1,
+     const at::Tensor& p2,
+     const at::Tensor& lengths1,
+     const at::Tensor& lengths2,
+     const at::Tensor& idxs,
+     const at::Tensor& grad_dists) {
+   const auto N = p1.size(0);
+   const auto P1 = p1.size(1);
+   const auto P2 = p2.size(1);
+   const auto D = p2.size(2);
+   const auto K = idxs.size(2);
+
+   AT_ASSERTM(p1.size(2) == D, "Point sets must have the same last dimension");
+   AT_ASSERTM(idxs.size(0) == N, "KNN idxs must have the same batch dimension");
+   AT_ASSERTM(
+       idxs.size(1) == P1, "KNN idxs must have the same point dimension as p1");
+   AT_ASSERTM(grad_dists.size(0) == N);
+   AT_ASSERTM(grad_dists.size(1) == P1);
+   AT_ASSERTM(grad_dists.size(2) == K);
+
+   auto grad_p1 = at::zeros({N, P1, D}, p1.options());
+   auto grad_p2 = at::zeros({N, P2, D}, p2.options());
+
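+   // Fixed launch configuration; the grid-stride loop in the kernel covers
+   // all N * P1 * K * D elements regardless of the grid size.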
+   const int blocks = 64;
+   const int threads = 512;
+
+   KNearestNeighborBackwardKernel<<<blocks, threads>>>(
+       p1.data_ptr<float>(),
+       p2.data_ptr<float>(),
+       lengths1.data_ptr<int64_t>(),
+       lengths2.data_ptr<int64_t>(),
+       idxs.data_ptr<int64_t>(),
+       grad_dists.data_ptr<float>(),
+       grad_p1.data_ptr<float>(),
+       grad_p2.data_ptr<float>(),
+       N,
+       P1,
+       P2,
+       K,
+       D);
+
+   return std::make_tuple(grad_p1, grad_p2);
+ }