|
| 1 | +/* |
| 2 | + * Copyright (c) Facebook, Inc. and its affiliates. |
| 3 | + * All rights reserved. |
| 4 | + * |
| 5 | + * This source code is licensed under the BSD-style license found in the |
| 6 | + * LICENSE file in the root directory of this source tree. |
| 7 | + */ |
| 8 | + |
| 9 | +#include <ATen/ATen.h> |
| 10 | +#include <ATen/cuda/CUDAContext.h> |
| 11 | +#include <c10/cuda/CUDAGuard.h> |
| 12 | + |
| 13 | +// There is no intermediate memory, so no reason not to have blocksize=32. |
| 14 | +// 256 is a reasonable number of blocks. |
| 15 | + |
| 16 | +// DESIGN |
| 17 | +// We exploit the fact that n_samples is not tiny. |
| 18 | +// A chunk of work is T*blocksize many samples from |
| 19 | +// a single batch elememt. |
| 20 | +// For each batch element there will be |
| 21 | +// chunks_per_batch = 1 + (n_samples-1)/(T*blocksize) of them. |
| 22 | +// The number of potential chunks to do is |
| 23 | +// n_chunks = chunks_per_batch * n_batches. |
| 24 | +// These chunks are divided among the gridSize-many blocks. |
| 25 | +// In block b, we work on chunks b, b+gridSize, b+2*gridSize etc . |
| 26 | +// In chunk i, we work on batch_element i/chunks_per_batch |
| 27 | +// on samples starting from (i%chunks_per_batch) * (T*blocksize) |
| 28 | + |
| 29 | +// BEGIN HYPOTHETICAL |
| 30 | +// Another option (not implemented) if batch_size was always large |
| 31 | +// would be as follows. |
| 32 | + |
| 33 | +// A chunk of work is S samples from each of blocksize-many |
| 34 | +// batch elements. |
| 35 | +// For each batch element there will be |
| 36 | +// chunks_per_batch = (1+(n_samples-1)/S) of them. |
| 37 | +// The number of potential chunks to do is |
| 38 | +// n_chunks = chunks_per_batch * (1+(n_batches-1)/blocksize) |
| 39 | +// These chunks are divided among the gridSize-many blocks. |
| 40 | +// In block b, we work on chunks b, b+gridSize, b+2*gridSize etc . |
| 41 | +// In chunk i, we work on samples starting from S*(i%chunks_per_batch) |
| 42 | +// on batch elements starting from blocksize*(i/chunks_per_batch). |
| 43 | +// END HYPOTHETICAL |
| 44 | + |
| 45 | +__global__ void SamplePdfCudaKernel( |
| 46 | + const float* __restrict__ bins, |
| 47 | + const float* __restrict__ weights, |
| 48 | + float* __restrict__ outputs, |
| 49 | + float eps, |
| 50 | + const int T, |
| 51 | + const int64_t batch_size, |
| 52 | + const int64_t n_bins, |
| 53 | + const int64_t n_samples) { |
| 54 | + const int64_t chunks_per_batch = 1 + (n_samples - 1) / (T * blockDim.x); |
| 55 | + const int64_t n_chunks = chunks_per_batch * batch_size; |
| 56 | + |
| 57 | + for (int64_t i_chunk = blockIdx.x; i_chunk < n_chunks; i_chunk += gridDim.x) { |
| 58 | + // Loop over the chunks. |
| 59 | + int64_t i_batch_element = i_chunk / chunks_per_batch; |
| 60 | + int64_t sample_start = (i_chunk % chunks_per_batch) * (T * blockDim.x); |
| 61 | + const float* const weight_startp = weights + n_bins * i_batch_element; |
| 62 | + const float* const bin_startp = bins + (1 + n_bins) * i_batch_element; |
| 63 | + |
| 64 | + // Each chunk looks at a single batch element, so we do the preprocessing |
| 65 | + // which depends on the batch element, namely finding the total weight. |
| 66 | + // Idenntical work is being done in sync here by every thread of the block. |
| 67 | + float total_weight = eps; |
| 68 | + for (int64_t i_bin = 0; i_bin < n_bins; ++i_bin) { |
| 69 | + total_weight += weight_startp[i_bin]; |
| 70 | + } |
| 71 | + |
| 72 | + float* const output_startp = |
| 73 | + outputs + n_samples * i_batch_element + sample_start; |
| 74 | + |
| 75 | + for (int t = 0; t < T; ++t) { |
| 76 | + // Loop over T, which is the number of samples each thread makes within |
| 77 | + // the chunk. |
| 78 | + const int64_t i_sample_within_chunk = threadIdx.x + t * blockDim.x; |
| 79 | + if (sample_start + i_sample_within_chunk >= n_samples) { |
| 80 | + // Some threads need to exit early because the sample they would |
| 81 | + // make is unwanted. |
| 82 | + continue; |
| 83 | + } |
| 84 | + // output_startp[i_sample_within_chunk] contains the quantile we (i.e. |
| 85 | + // this thread) are calcvulating. |
| 86 | + float uniform = total_weight * output_startp[i_sample_within_chunk]; |
| 87 | + int64_t i_bin = 0; |
| 88 | + // We find the bin containing the quantile by walking along the weights. |
| 89 | + // This loop must be thread dependent. I.e. the whole warp will wait until |
| 90 | + // every thread has found the bin for its quantile. |
| 91 | + // It may be best to write it differently. |
| 92 | + while (i_bin + 1 < n_bins && uniform > weight_startp[i_bin]) { |
| 93 | + uniform -= weight_startp[i_bin]; |
| 94 | + ++i_bin; |
| 95 | + } |
| 96 | + |
| 97 | + // Now we know which bin to look in, we use linear interpolation |
| 98 | + // to find the location of the quantile within the bin, and |
| 99 | + // write the answer back. |
| 100 | + float bin_start = bin_startp[i_bin]; |
| 101 | + float bin_end = bin_startp[i_bin + 1]; |
| 102 | + float bin_weight = weight_startp[i_bin]; |
| 103 | + float output_value = bin_start; |
| 104 | + if (uniform > bin_weight) { |
| 105 | + output_value = bin_end; |
| 106 | + } else if (bin_weight > eps) { |
| 107 | + output_value += (uniform / bin_weight) * (bin_end - bin_start); |
| 108 | + } |
| 109 | + output_startp[i_sample_within_chunk] = output_value; |
| 110 | + } |
| 111 | + } |
| 112 | +} |
| 113 | + |
| 114 | +void SamplePdfCuda( |
| 115 | + const at::Tensor& bins, |
| 116 | + const at::Tensor& weights, |
| 117 | + const at::Tensor& outputs, |
| 118 | + float eps) { |
| 119 | + // Check inputs are on the same device |
| 120 | + at::TensorArg bins_t{bins, "bins", 1}, weights_t{weights, "weights", 2}, |
| 121 | + outputs_t{outputs, "outputs", 3}; |
| 122 | + at::CheckedFrom c = "SamplePdfCuda"; |
| 123 | + at::checkAllSameGPU(c, {bins_t, weights_t, outputs_t}); |
| 124 | + at::checkAllSameType(c, {bins_t, weights_t, outputs_t}); |
| 125 | + |
| 126 | + // Set the device for the kernel launch based on the device of the input |
| 127 | + at::cuda::CUDAGuard device_guard(bins.device()); |
| 128 | + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); |
| 129 | + |
| 130 | + const int64_t batch_size = bins.size(0); |
| 131 | + const int64_t n_bins = weights.size(1); |
| 132 | + const int64_t n_samples = outputs.size(1); |
| 133 | + |
| 134 | + const int64_t threads = 32; |
| 135 | + const int64_t T = n_samples <= threads ? 1 : 2; |
| 136 | + const int64_t chunks_per_batch = 1 + (n_samples - 1) / (T * threads); |
| 137 | + const int64_t n_chunks = chunks_per_batch * batch_size; |
| 138 | + |
| 139 | + const int64_t max_blocks = 1024; |
| 140 | + const int64_t blocks = n_chunks < max_blocks ? n_chunks : max_blocks; |
| 141 | + |
| 142 | + SamplePdfCudaKernel<<<blocks, threads, 0, stream>>>( |
| 143 | + bins.contiguous().data_ptr<float>(), |
| 144 | + weights.contiguous().data_ptr<float>(), |
| 145 | + outputs.data_ptr<float>(), // Checked contiguous in header file. |
| 146 | + eps, |
| 147 | + T, |
| 148 | + batch_size, |
| 149 | + n_bins, |
| 150 | + n_samples); |
| 151 | + |
| 152 | + AT_CUDA_CHECK(cudaGetLastError()); |
| 153 | +} |
0 commit comments