|
| 1 | +/* |
| 2 | + * Copyright (c) Facebook, Inc. and its affiliates. |
| 3 | + * All rights reserved. |
| 4 | + * |
| 5 | + * This source code is licensed under the BSD-style license found in the |
| 6 | + * LICENSE file in the root directory of this source tree. |
| 7 | + */ |
| 8 | + |
| 9 | +#include <ATen/ATen.h> |
| 10 | +#include <ATen/cuda/CUDAContext.h> |
| 11 | +#include <c10/cuda/CUDAGuard.h> |
| 12 | +#include <torch/extension.h> |
| 13 | + |
| 14 | +using torch::PackedTensorAccessor64; |
| 15 | +using torch::RestrictPtrTraits; |
| 16 | + |
| 17 | +// A chunk of work is blocksize-many points. |
| 18 | +// There are N clouds in the batch, and P points in each cloud. |
| 19 | +// The number of potential chunks to do per cloud is (1+(P-1)/blocksize), |
| 20 | +// which we call chunks_per_cloud. |
| 21 | +// These (N*chunks_per_cloud) chunks are divided among the gridSize-many blocks. |
| 22 | +// In block b, we work on chunks b, b+gridSize, b+2*gridSize etc . |
| 23 | +// In chunk i, we work on cloud (i/chunks_per_cloud) on points starting from |
| 24 | +// blocksize*(i%chunks_per_cloud). |
| 25 | + |
| 26 | +// Explanation of the calculation is in the cpp file. |
| 27 | + |
| 28 | +// EightDirections(t) runs t(a,b,c) for every combination of boolean a, b, c. |
| 29 | +template <class T> |
| 30 | +static __device__ void EightDirections(T&& t) { |
| 31 | + t(false, false, false); |
| 32 | + t(false, false, true); |
| 33 | + t(false, true, false); |
| 34 | + t(false, true, true); |
| 35 | + t(true, false, false); |
| 36 | + t(true, false, true); |
| 37 | + t(true, true, false); |
| 38 | + t(true, true, true); |
| 39 | +} |
| 40 | + |
| 41 | +__global__ void PointsToVolumesForwardKernel( |
| 42 | + const PackedTensorAccessor64<float, 3, RestrictPtrTraits> points_3d, |
| 43 | + const PackedTensorAccessor64<float, 3, RestrictPtrTraits> points_features, |
| 44 | + PackedTensorAccessor64<float, 5, RestrictPtrTraits> volume_densities, |
| 45 | + PackedTensorAccessor64<float, 5, RestrictPtrTraits> volume_features, |
| 46 | + PackedTensorAccessor64<int64_t, 2, RestrictPtrTraits> grid_sizes, |
| 47 | + PackedTensorAccessor64<float, 2, RestrictPtrTraits> mask, |
| 48 | + const float point_weight, |
| 49 | + const bool align_corners, |
| 50 | + const bool splat, |
| 51 | + const int64_t batch_size, |
| 52 | + const int64_t P, |
| 53 | + const int64_t n_features) { |
| 54 | + const int64_t chunks_per_cloud = (1 + (P - 1) / blockDim.x); |
| 55 | + const int64_t chunks_to_do = batch_size * chunks_per_cloud; |
| 56 | + const int scale_offset = align_corners ? 1 : 0; |
| 57 | + const float offset = align_corners ? 0 : 0.5; |
| 58 | + for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) { |
| 59 | + const int64_t batch_index = chunk / chunks_per_cloud; |
| 60 | + const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud); |
| 61 | + int64_t point_idx = start_point + threadIdx.x; |
| 62 | + if (point_idx >= P) { |
| 63 | + continue; |
| 64 | + } |
| 65 | + if (mask[batch_index][point_idx] == 0) { |
| 66 | + continue; |
| 67 | + } |
| 68 | + auto volume_densities_aa = volume_densities[batch_index][0]; |
| 69 | + auto volume_features_aa = volume_features[batch_index]; |
| 70 | + auto point = points_3d[batch_index][point_idx]; |
| 71 | + auto point_features = points_features[batch_index][point_idx]; |
| 72 | + const int64_t grid_size_x = grid_sizes[batch_index][2]; |
| 73 | + const int64_t grid_size_y = grid_sizes[batch_index][1]; |
| 74 | + const int64_t grid_size_z = grid_sizes[batch_index][0]; |
| 75 | + auto increment_location = |
| 76 | + [&](int64_t x, int64_t y, int64_t z, float weight) { |
| 77 | + if (x >= grid_size_x || y >= grid_size_y || z >= grid_size_z) { |
| 78 | + return; |
| 79 | + } |
| 80 | + if (x < 0 || y < 0 || z < 0) { |
| 81 | + return; |
| 82 | + } |
| 83 | + |
| 84 | + atomicAdd(&volume_densities_aa[z][y][x], weight * point_weight); |
| 85 | + |
| 86 | + for (int64_t feature_idx = 0; feature_idx < n_features; |
| 87 | + ++feature_idx) { |
| 88 | + atomicAdd( |
| 89 | + &volume_features_aa[feature_idx][z][y][x], |
| 90 | + point_features[feature_idx] * weight * point_weight); |
| 91 | + } |
| 92 | + }; |
| 93 | + if (!splat) { |
| 94 | + long x = std::lround( |
| 95 | + (point[0] + 1) * 0.5 * (grid_size_x - scale_offset) - offset); |
| 96 | + long y = std::lround( |
| 97 | + (point[1] + 1) * 0.5 * (grid_size_y - scale_offset) - offset); |
| 98 | + long z = std::lround( |
| 99 | + (point[2] + 1) * 0.5 * (grid_size_z - scale_offset) - offset); |
| 100 | + increment_location(x, y, z, 1); |
| 101 | + } else { |
| 102 | + float x = 0, y = 0, z = 0; |
| 103 | + float rx = std::modf( |
| 104 | + (point[0] + 1) * 0.5 * (grid_size_x - scale_offset) - offset, &x); |
| 105 | + float ry = std::modf( |
| 106 | + (point[1] + 1) * 0.5 * (grid_size_y - scale_offset) - offset, &y); |
| 107 | + float rz = std::modf( |
| 108 | + (point[2] + 1) * 0.5 * (grid_size_z - scale_offset) - offset, &z); |
| 109 | + auto handle_point = [&](bool up_x, bool up_y, bool up_z) { |
| 110 | + float weight = |
| 111 | + (up_x ? rx : 1 - rx) * (up_y ? ry : 1 - ry) * (up_z ? rz : 1 - rz); |
| 112 | + increment_location(x + up_x, y + up_y, z + up_z, weight); |
| 113 | + }; |
| 114 | + EightDirections(handle_point); |
| 115 | + } |
| 116 | + } |
| 117 | +} |
| 118 | + |
| 119 | +void PointsToVolumesForwardCuda( |
| 120 | + const torch::Tensor& points_3d, |
| 121 | + const torch::Tensor& points_features, |
| 122 | + const torch::Tensor& volume_densities, |
| 123 | + const torch::Tensor& volume_features, |
| 124 | + const torch::Tensor& grid_sizes, |
| 125 | + const torch::Tensor& mask, |
| 126 | + const float point_weight, |
| 127 | + const bool align_corners, |
| 128 | + const bool splat) { |
| 129 | + // Check inputs are on the same device |
| 130 | + at::TensorArg points_3d_t{points_3d, "points_3d", 1}, |
| 131 | + points_features_t{points_features, "points_features", 2}, |
| 132 | + volume_densities_t{volume_densities, "volume_densities", 3}, |
| 133 | + volume_features_t{volume_features, "volume_features", 4}, |
| 134 | + grid_sizes_t{grid_sizes, "grid_sizes", 5}, mask_t{mask, "mask", 6}; |
| 135 | + at::CheckedFrom c = "PointsToVolumesForwardCuda"; |
| 136 | + at::checkAllSameGPU( |
| 137 | + c, |
| 138 | + {points_3d_t, |
| 139 | + points_features_t, |
| 140 | + volume_densities_t, |
| 141 | + volume_features_t, |
| 142 | + grid_sizes_t, |
| 143 | + mask_t}); |
| 144 | + |
| 145 | + // Set the device for the kernel launch based on the device of the input |
| 146 | + at::cuda::CUDAGuard device_guard(points_3d.device()); |
| 147 | + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); |
| 148 | + |
| 149 | + const int blocks = 1024; |
| 150 | + const int threads = 32; |
| 151 | + |
| 152 | + const int64_t batch_size = points_3d.size(0); |
| 153 | + const int64_t P = points_3d.size(1); |
| 154 | + const int64_t n_features = points_features.size(2); |
| 155 | + |
| 156 | + PointsToVolumesForwardKernel<<<blocks, threads, 0, stream>>>( |
| 157 | + points_3d.packed_accessor64<float, 3, RestrictPtrTraits>(), |
| 158 | + points_features.packed_accessor64<float, 3, RestrictPtrTraits>(), |
| 159 | + volume_densities.packed_accessor64<float, 5, RestrictPtrTraits>(), |
| 160 | + volume_features.packed_accessor64<float, 5, RestrictPtrTraits>(), |
| 161 | + grid_sizes.packed_accessor64<int64_t, 2, RestrictPtrTraits>(), |
| 162 | + mask.packed_accessor64<float, 2, RestrictPtrTraits>(), |
| 163 | + point_weight, |
| 164 | + align_corners, |
| 165 | + splat, |
| 166 | + batch_size, |
| 167 | + P, |
| 168 | + n_features); |
| 169 | +} |
| 170 | + |
| 171 | +__global__ void PointsToVolumesBackwardKernel( |
| 172 | + const PackedTensorAccessor64<float, 3, RestrictPtrTraits> points_3d, |
| 173 | + const PackedTensorAccessor64<float, 3, RestrictPtrTraits> points_features, |
| 174 | + const PackedTensorAccessor64<int64_t, 2, RestrictPtrTraits> grid_sizes, |
| 175 | + const PackedTensorAccessor64<float, 2, RestrictPtrTraits> mask, |
| 176 | + PackedTensorAccessor64<float, 5, RestrictPtrTraits> grad_volume_densities, |
| 177 | + PackedTensorAccessor64<float, 5, RestrictPtrTraits> grad_volume_features, |
| 178 | + PackedTensorAccessor64<float, 3, RestrictPtrTraits> grad_points_3d, |
| 179 | + PackedTensorAccessor64<float, 3, RestrictPtrTraits> grad_points_features, |
| 180 | + const float point_weight, |
| 181 | + const bool align_corners, |
| 182 | + const bool splat, |
| 183 | + const int64_t batch_size, |
| 184 | + const int64_t P, |
| 185 | + const int64_t n_features) { |
| 186 | + const int64_t chunks_per_cloud = (1 + (P - 1) / blockDim.x); |
| 187 | + const int64_t chunks_to_do = batch_size * chunks_per_cloud; |
| 188 | + const int scale_offset = align_corners ? 1 : 0; |
| 189 | + const float offset = align_corners ? 0 : 0.5; |
| 190 | + // Note that the gradients belonging to each point are only touched by |
| 191 | + // a single thread in one of our "chunks", which is in a single block. |
| 192 | + // So unlike in the forward pass, there's no need for atomics here. |
| 193 | + for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) { |
| 194 | + const int64_t batch_index = chunk / chunks_per_cloud; |
| 195 | + const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud); |
| 196 | + int64_t point_idx = start_point + threadIdx.x; |
| 197 | + if (point_idx >= P) { |
| 198 | + continue; |
| 199 | + } |
| 200 | + if (mask[batch_index][point_idx] == 0) { |
| 201 | + continue; |
| 202 | + } |
| 203 | + auto point = points_3d[batch_index][point_idx]; |
| 204 | + auto point_features = points_features[batch_index][point_idx]; |
| 205 | + auto grad_point = grad_points_3d[batch_index][point_idx]; |
| 206 | + auto grad_point_features = grad_points_features[batch_index][point_idx]; |
| 207 | + auto grad_volume_densities_a = grad_volume_densities[batch_index][0]; |
| 208 | + auto grad_volume_features_a = grad_volume_features[batch_index]; |
| 209 | + const int64_t grid_size_x = grid_sizes[batch_index][2]; |
| 210 | + const int64_t grid_size_y = grid_sizes[batch_index][1]; |
| 211 | + const int64_t grid_size_z = grid_sizes[batch_index][0]; |
| 212 | + |
| 213 | + auto increment_location = |
| 214 | + [&](int64_t x, int64_t y, int64_t z, float weight) { |
| 215 | + if (x >= grid_size_x || y >= grid_size_y || z >= grid_size_z) { |
| 216 | + return false; |
| 217 | + } |
| 218 | + if (x < 0 || y < 0 || z < 0) { |
| 219 | + return false; |
| 220 | + } |
| 221 | + |
| 222 | + // This is a forward line, for comparison |
| 223 | + // volume_densities_aa[z][y][x] += weight * point_weight; |
| 224 | + |
| 225 | + for (int64_t feature_idx = 0; feature_idx < n_features; |
| 226 | + ++feature_idx) { |
| 227 | + // This is a forward line, for comparison |
| 228 | + // volume_features_aa[feature_idx][z][y][x] += |
| 229 | + // point_features[feature_idx] * weight * point_weight; |
| 230 | + grad_point_features[feature_idx] += |
| 231 | + grad_volume_features_a[feature_idx][z][y][x] * weight * |
| 232 | + point_weight; |
| 233 | + } |
| 234 | + return true; |
| 235 | + }; |
| 236 | + |
| 237 | + if (!splat) { |
| 238 | + long x = std::lround( |
| 239 | + (point[0] + 1) * 0.5 * (grid_size_x - scale_offset) - offset); |
| 240 | + long y = std::lround( |
| 241 | + (point[1] + 1) * 0.5 * (grid_size_y - scale_offset) - offset); |
| 242 | + long z = std::lround( |
| 243 | + (point[2] + 1) * 0.5 * (grid_size_z - scale_offset) - offset); |
| 244 | + increment_location(x, y, z, 1); |
| 245 | + } else { |
| 246 | + float x = 0, y = 0, z = 0; |
| 247 | + float rx = std::modf( |
| 248 | + (point[0] + 1) * 0.5 * (grid_size_x - scale_offset) - offset, &x); |
| 249 | + float ry = std::modf( |
| 250 | + (point[1] + 1) * 0.5 * (grid_size_y - scale_offset) - offset, &y); |
| 251 | + float rz = std::modf( |
| 252 | + (point[2] + 1) * 0.5 * (grid_size_z - scale_offset) - offset, &z); |
| 253 | + auto handle_point = [&](bool up_x, bool up_y, bool up_z) { |
| 254 | + float weight_x = (up_x ? rx : 1 - rx); |
| 255 | + float weight_y = (up_y ? ry : 1 - ry); |
| 256 | + float weight_z = (up_z ? rz : 1 - rz); |
| 257 | + float weight = weight_x * weight_y * weight_z; |
| 258 | + if (increment_location(x + up_x, y + up_y, z + up_z, weight)) { |
| 259 | + // weight * point_weight has been added to |
| 260 | + // volume_densities_aa[z+up_z][y+up_y][x+up_x] |
| 261 | + // Also for each feature_idx, |
| 262 | + // point_features[feature_idx] * weight * point_weight |
| 263 | + // has been added to |
| 264 | + // volume_features_aa[feature_idx][z+up_z][y+up_y][x+up_x] |
| 265 | + |
| 266 | + double source_gradient = |
| 267 | + grad_volume_densities_a[z + up_z][y + up_y][x + up_x]; |
| 268 | + for (int64_t feature_idx = 0; feature_idx < n_features; |
| 269 | + ++feature_idx) { |
| 270 | + source_gradient += point_features[feature_idx] * |
| 271 | + grad_volume_features_a[feature_idx][z + up_z][y + up_y] |
| 272 | + [x + up_x]; |
| 273 | + } |
| 274 | + grad_point[0] += source_gradient * (up_x ? 1 : -1) * weight_y * |
| 275 | + weight_z * 0.5 * (grid_size_x - scale_offset) * point_weight; |
| 276 | + grad_point[1] += source_gradient * (up_y ? 1 : -1) * weight_x * |
| 277 | + weight_z * 0.5 * (grid_size_y - scale_offset) * point_weight; |
| 278 | + grad_point[2] += source_gradient * (up_z ? 1 : -1) * weight_x * |
| 279 | + weight_y * 0.5 * (grid_size_z - scale_offset) * point_weight; |
| 280 | + } |
| 281 | + }; |
| 282 | + EightDirections(handle_point); |
| 283 | + } |
| 284 | + } |
| 285 | +} |
| 286 | + |
| 287 | +void PointsToVolumesBackwardCuda( |
| 288 | + const torch::Tensor& points_3d, |
| 289 | + const torch::Tensor& points_features, |
| 290 | + const torch::Tensor& grid_sizes, |
| 291 | + const torch::Tensor& mask, |
| 292 | + const float point_weight, |
| 293 | + const bool align_corners, |
| 294 | + const bool splat, |
| 295 | + const torch::Tensor& grad_volume_densities, |
| 296 | + const torch::Tensor& grad_volume_features, |
| 297 | + const torch::Tensor& grad_points_3d, |
| 298 | + const torch::Tensor& grad_points_features) { |
| 299 | + // Check inputs are on the same device |
| 300 | + at::TensorArg points_3d_t{points_3d, "points_3d", 1}, |
| 301 | + points_features_t{points_features, "points_features", 2}, |
| 302 | + grid_sizes_t{grid_sizes, "grid_sizes", 3}, mask_t{mask, "mask", 4}, |
| 303 | + grad_volume_densities_t{ |
| 304 | + grad_volume_densities, "grad_volume_densities", 8}, |
| 305 | + grad_volume_features_t{grad_volume_features, "grad_volume_features", 9}, |
| 306 | + grad_points_3d_t{grad_points_3d, "grad_points_3d", 10}, |
| 307 | + grad_points_features_t{grad_points_features, "grad_points_features", 11}; |
| 308 | + |
| 309 | + at::CheckedFrom c = "PointsToVolumesBackwardCuda"; |
| 310 | + at::checkAllSameGPU( |
| 311 | + c, |
| 312 | + {points_3d_t, |
| 313 | + points_features_t, |
| 314 | + grid_sizes_t, |
| 315 | + mask_t, |
| 316 | + grad_volume_densities_t, |
| 317 | + grad_volume_features_t, |
| 318 | + grad_points_3d_t, |
| 319 | + grad_points_features_t}); |
| 320 | + |
| 321 | + // Set the device for the kernel launch based on the device of the input |
| 322 | + at::cuda::CUDAGuard device_guard(points_3d.device()); |
| 323 | + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); |
| 324 | + |
| 325 | + const int blocks = 1024; |
| 326 | + const int threads = 32; |
| 327 | + |
| 328 | + const int64_t batch_size = points_3d.size(0); |
| 329 | + const int64_t P = points_3d.size(1); |
| 330 | + const int64_t n_features = points_features.size(2); |
| 331 | + |
| 332 | + PointsToVolumesBackwardKernel<<<blocks, threads, 0, stream>>>( |
| 333 | + points_3d.packed_accessor64<float, 3, RestrictPtrTraits>(), |
| 334 | + points_features.packed_accessor64<float, 3, RestrictPtrTraits>(), |
| 335 | + grid_sizes.packed_accessor64<int64_t, 2, RestrictPtrTraits>(), |
| 336 | + mask.packed_accessor64<float, 2, RestrictPtrTraits>(), |
| 337 | + grad_volume_densities.packed_accessor64<float, 5, RestrictPtrTraits>(), |
| 338 | + grad_volume_features.packed_accessor64<float, 5, RestrictPtrTraits>(), |
| 339 | + grad_points_3d.packed_accessor64<float, 3, RestrictPtrTraits>(), |
| 340 | + grad_points_features.packed_accessor64<float, 3, RestrictPtrTraits>(), |
| 341 | + point_weight, |
| 342 | + align_corners, |
| 343 | + splat, |
| 344 | + batch_size, |
| 345 | + P, |
| 346 | + n_features); |
| 347 | +} |
0 commit comments