@@ -2,6 +2,8 @@
 
 #include <ATen/ATen.h>
 #include <ATen/core/TensorAccessor.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
 
 #include <cuda.h>
 #include <cuda_runtime.h>
@@ -151,26 +153,43 @@ at::Tensor weightedSumNormCudaForward(
     const at::Tensor& features,
     const at::Tensor& alphas,
     const at::Tensor& points_idx) {
+  // Check inputs are on the same device
+  at::TensorArg features_t{features, "features", 1},
+      alphas_t{alphas, "alphas", 2}, points_idx_t{points_idx, "points_idx", 3};
+  at::CheckedFrom c = "weightedSumNormCudaForward";
+  at::checkAllSameGPU(c, {features_t, alphas_t, points_idx_t});
+  at::checkAllSameType(c, {features_t, alphas_t});
+
+  // Set the device for the kernel launch based on the device of the input
+  at::cuda::CUDAGuard device_guard(features.device());
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
   const int64_t batch_size = points_idx.size(0);
   const int64_t C = features.size(0);
   const int64_t H = points_idx.size(2);
   const int64_t W = points_idx.size(3);
 
   auto result = at::zeros({batch_size, C, H, W}, features.options());
 
+  if (result.numel() == 0) {
+    AT_CUDA_CHECK(cudaGetLastError());
+    return result;
+  }
+
   const dim3 threadsPerBlock(64);
   const dim3 numBlocks(batch_size, 1024 / batch_size + 1);
 
   // TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports
   // doubles. Currently, support is for floats only.
   // clang-format off
-  weightedSumNormCudaForwardKernel<<<numBlocks, threadsPerBlock>>>(
+  weightedSumNormCudaForwardKernel<<<numBlocks, threadsPerBlock, 0, stream>>>(
       result.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
       features.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
       alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
       points_idx.packed_accessor64<int64_t, 4, at::RestrictPtrTraits>());
   // clang-format on
 
+  AT_CUDA_CHECK(cudaGetLastError());
   return result;
 }
 
@@ -179,17 +198,34 @@ std::tuple<at::Tensor, at::Tensor> weightedSumNormCudaBackward(
     const at::Tensor& features,
     const at::Tensor& alphas,
     const at::Tensor& points_idx) {
+  // Check inputs are on the same device
+  at::TensorArg grad_outputs_t{grad_outputs, "grad_outputs", 1},
+      features_t{features, "features", 2}, alphas_t{alphas, "alphas", 3},
+      points_idx_t{points_idx, "points_idx", 4};
+  at::CheckedFrom c = "weightedSumNormCudaBackward";
+  at::checkAllSameGPU(c, {grad_outputs_t, features_t, alphas_t, points_idx_t});
+  at::checkAllSameType(c, {grad_outputs_t, features_t, alphas_t});
+
+  // Set the device for the kernel launch based on the device of the input
+  at::cuda::CUDAGuard device_guard(features.device());
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
   auto grad_features = at::zeros_like(features);
   auto grad_alphas = at::zeros_like(alphas);
 
+  if (grad_features.numel() == 0 || grad_alphas.numel() == 0) {
+    AT_CUDA_CHECK(cudaGetLastError());
+    return std::make_tuple(grad_features, grad_alphas);
+  }
+
   const int64_t bs = points_idx.size(0);
 
   const dim3 threadsPerBlock(64);
   const dim3 numBlocks(bs, 1024 / bs + 1);
 
   // TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports
   // doubles. Currently, support is for floats only.
-  weightedSumNormCudaBackwardKernel<<<numBlocks, threadsPerBlock>>>(
+  weightedSumNormCudaBackwardKernel<<<numBlocks, threadsPerBlock, 0, stream>>>(
       // clang-format off
       grad_features.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
       grad_alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
@@ -198,6 +234,6 @@ std::tuple<at::Tensor, at::Tensor> weightedSumNormCudaBackward(
       alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
       points_idx.packed_accessor64<int64_t, 4, at::RestrictPtrTraits>());
   // clang-format on
-
+  AT_CUDA_CHECK(cudaGetLastError());
   return std::make_tuple(grad_features, grad_alphas);
 }
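
Note: both functions apply the same launch pattern. A minimal standalone sketch of that pattern is below, not part of the diff: device checks via TensorArg, a CUDAGuard pinned to the input's device, launching on the current ATen stream, an early return for empty tensors, and a post-launch error check. The kernel scaleKernel and wrapper scaleCudaForward are hypothetical names for illustration; only the ATen/c10 calls (at::cuda::CUDAGuard, at::cuda::getCurrentCUDAStream, AT_CUDA_CHECK) are taken from the diff itself.

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

// Hypothetical kernel, used only to illustrate the launch pattern:
// multiplies every element of a float buffer by a scalar.
__global__ void scaleKernel(float* data, int64_t n, float s) {
  const int64_t i =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (i < n) {
    data[i] *= s;
  }
}

// Assumes a float32 CUDA tensor.
at::Tensor scaleCudaForward(const at::Tensor& input, float s) {
  // Check the input is on a GPU, mirroring the diff's TensorArg checks.
  at::TensorArg input_t{input, "input", 1};
  at::checkAllSameGPU("scaleCudaForward", {input_t});

  // Make every CUDA call below target the device that owns `input`,
  // and launch on the stream PyTorch is currently using.
  at::cuda::CUDAGuard device_guard(input.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  auto result = input.clone();
  if (result.numel() == 0) {
    // Nothing to launch; still surface any pending CUDA error.
    AT_CUDA_CHECK(cudaGetLastError());
    return result;
  }

  const int64_t n = result.numel();
  const dim3 threadsPerBlock(64);
  const dim3 numBlocks(static_cast<unsigned int>(
      (n + threadsPerBlock.x - 1) / threadsPerBlock.x));
  scaleKernel<<<numBlocks, threadsPerBlock, 0, stream>>>(
      result.data_ptr<float>(), n, s);

  // Catch launch failures (e.g. an invalid grid) right away.
  AT_CUDA_CHECK(cudaGetLastError());
  return result;
}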