Commit 1ea2b72

bottler authored and facebook-github-bot committed

sample_pdf CUDA and C++ implementations.

Summary: Implement the sample_pdf function from the NeRF project as compiled operators. The binary search (in searchsorted) is replaced with a low-tech linear search, but this is not a problem for the envisaged numbers of bins.

Reviewed By: gkioxari
Differential Revision: D26312535
fbshipit-source-id: df1c3119cd63d944380ed1b2657b6ad81d743e49

1 parent 7d7d00f commit 1ea2b72
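
The sample_pdf operator performs inverse transform sampling over a piecewise-constant PDF: each bin's weight is its unnormalized probability mass, and each input quantile is mapped to a point inside the bin where the cumulative weight crosses it. Below is a minimal standalone C++ sketch of the per-quantile logic only, including the linear search mentioned in the summary; the function and variable names here are illustrative and are not part of the commit.

#include <vector>

// Sketch of the per-quantile inverse-CDF step: scale the quantile by the
// total weight, walk the weights linearly to find the containing bin,
// then interpolate inside that bin.
float sample_pdf_one(
    const std::vector<float>& bin_edges, // n_bins + 1 edges
    const std::vector<float>& weights,   // n_bins non-negative weights
    float quantile,                      // uniform draw in [0, 1)
    float eps) {
  float total = eps;
  for (float w : weights) {
    total += w;
  }
  float u = quantile * total;
  size_t i = 0;
  while (i + 1 < weights.size() && u > weights[i]) {
    u -= weights[i];
    ++i;
  }
  const float lo = bin_edges[i], hi = bin_edges[i + 1], w = weights[i];
  if (u > w) {
    return hi;
  }
  return w > eps ? lo + (u / w) * (hi - lo) : lo;
}

The compiled operators in this commit apply the same logic in batch, reading the quantiles from the outputs tensor and overwriting them in place.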

File tree

7 files changed: +488 -3 lines changed

pytorch3d/csrc/ext.cpp (+4)

@@ -26,6 +26,7 @@
 #include "point_mesh/point_mesh_cuda.h"
 #include "rasterize_meshes/rasterize_meshes.h"
 #include "rasterize_points/rasterize_points.h"
+#include "sample_pdf/sample_pdf.h"

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("face_areas_normals_forward", &FaceAreasNormalsForward);
@@ -83,6 +84,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("point_face_array_dist_forward", &PointFaceArrayDistanceForward);
   m.def("point_face_array_dist_backward", &PointFaceArrayDistanceBackward);

+  // Sample PDF
+  m.def("sample_pdf", &SamplePdf);
+
   // Pulsar.
 #ifdef PULSAR_LOGGING_ENABLED
   c10::ShowLogInfoToStderr();

+153

@@ -0,0 +1,153 @@
/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

// There is no intermediate memory, so no reason not to have blocksize=32.
// 256 is a reasonable number of blocks.

// DESIGN
// We exploit the fact that n_samples is not tiny.
// A chunk of work is T*blocksize many samples from
// a single batch element.
// For each batch element there will be
// chunks_per_batch = 1 + (n_samples-1)/(T*blocksize) of them.
// The number of potential chunks to do is
// n_chunks = chunks_per_batch * n_batches.
// These chunks are divided among the gridSize-many blocks.
// In block b, we work on chunks b, b+gridSize, b+2*gridSize etc.
// In chunk i, we work on batch_element i/chunks_per_batch
// on samples starting from (i%chunks_per_batch) * (T*blocksize).

// BEGIN HYPOTHETICAL
// Another option (not implemented), if batch_size was always large,
// would be as follows.

// A chunk of work is S samples from each of blocksize-many
// batch elements.
// For each batch element there will be
// chunks_per_batch = (1+(n_samples-1)/S) of them.
// The number of potential chunks to do is
// n_chunks = chunks_per_batch * (1+(n_batches-1)/blocksize).
// These chunks are divided among the gridSize-many blocks.
// In block b, we work on chunks b, b+gridSize, b+2*gridSize etc.
// In chunk i, we work on samples starting from S*(i%chunks_per_batch)
// on batch elements starting from blocksize*(i/chunks_per_batch).
// END HYPOTHETICAL

__global__ void SamplePdfCudaKernel(
    const float* __restrict__ bins,
    const float* __restrict__ weights,
    float* __restrict__ outputs,
    float eps,
    const int T,
    const int64_t batch_size,
    const int64_t n_bins,
    const int64_t n_samples) {
  const int64_t chunks_per_batch = 1 + (n_samples - 1) / (T * blockDim.x);
  const int64_t n_chunks = chunks_per_batch * batch_size;

  for (int64_t i_chunk = blockIdx.x; i_chunk < n_chunks; i_chunk += gridDim.x) {
    // Loop over the chunks.
    int64_t i_batch_element = i_chunk / chunks_per_batch;
    int64_t sample_start = (i_chunk % chunks_per_batch) * (T * blockDim.x);
    const float* const weight_startp = weights + n_bins * i_batch_element;
    const float* const bin_startp = bins + (1 + n_bins) * i_batch_element;

    // Each chunk looks at a single batch element, so we do the preprocessing
    // which depends on the batch element, namely finding the total weight.
    // Identical work is being done in sync here by every thread of the block.
    float total_weight = eps;
    for (int64_t i_bin = 0; i_bin < n_bins; ++i_bin) {
      total_weight += weight_startp[i_bin];
    }

    float* const output_startp =
        outputs + n_samples * i_batch_element + sample_start;

    for (int t = 0; t < T; ++t) {
      // Loop over T, which is the number of samples each thread makes within
      // the chunk.
      const int64_t i_sample_within_chunk = threadIdx.x + t * blockDim.x;
      if (sample_start + i_sample_within_chunk >= n_samples) {
        // Some threads need to exit early because the sample they would
        // make is unwanted.
        continue;
      }
      // output_startp[i_sample_within_chunk] contains the quantile we (i.e.
      // this thread) are calculating.
      float uniform = total_weight * output_startp[i_sample_within_chunk];
      int64_t i_bin = 0;
      // We find the bin containing the quantile by walking along the weights.
      // This loop must be thread dependent, i.e. the whole warp will wait
      // until every thread has found the bin for its quantile.
      // It may be best to write it differently.
      while (i_bin + 1 < n_bins && uniform > weight_startp[i_bin]) {
        uniform -= weight_startp[i_bin];
        ++i_bin;
      }

      // Now we know which bin to look in, we use linear interpolation
      // to find the location of the quantile within the bin, and
      // write the answer back.
      float bin_start = bin_startp[i_bin];
      float bin_end = bin_startp[i_bin + 1];
      float bin_weight = weight_startp[i_bin];
      float output_value = bin_start;
      if (uniform > bin_weight) {
        output_value = bin_end;
      } else if (bin_weight > eps) {
        output_value += (uniform / bin_weight) * (bin_end - bin_start);
      }
      output_startp[i_sample_within_chunk] = output_value;
    }
  }
}

void SamplePdfCuda(
    const at::Tensor& bins,
    const at::Tensor& weights,
    const at::Tensor& outputs,
    float eps) {
  // Check inputs are on the same device
  at::TensorArg bins_t{bins, "bins", 1}, weights_t{weights, "weights", 2},
      outputs_t{outputs, "outputs", 3};
  at::CheckedFrom c = "SamplePdfCuda";
  at::checkAllSameGPU(c, {bins_t, weights_t, outputs_t});
  at::checkAllSameType(c, {bins_t, weights_t, outputs_t});

  // Set the device for the kernel launch based on the device of the input
  at::cuda::CUDAGuard device_guard(bins.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  const int64_t batch_size = bins.size(0);
  const int64_t n_bins = weights.size(1);
  const int64_t n_samples = outputs.size(1);

  const int64_t threads = 32;
  const int64_t T = n_samples <= threads ? 1 : 2;
  const int64_t chunks_per_batch = 1 + (n_samples - 1) / (T * threads);
  const int64_t n_chunks = chunks_per_batch * batch_size;

  const int64_t max_blocks = 1024;
  const int64_t blocks = n_chunks < max_blocks ? n_chunks : max_blocks;

  SamplePdfCudaKernel<<<blocks, threads, 0, stream>>>(
      bins.contiguous().data_ptr<float>(),
      weights.contiguous().data_ptr<float>(),
      outputs.data_ptr<float>(), // Checked contiguous in header file.
      eps,
      T,
      batch_size,
      n_bins,
      n_samples);

  AT_CUDA_CHECK(cudaGetLastError());
}
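
To make the DESIGN comment's chunking arithmetic concrete, here is a small host-side sketch with assumed sizes (n_samples = 100 and batch_size = 8 are made up for illustration; threads = 32 and T = 2 mirror the launch logic in SamplePdfCuda):

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t threads = 32; // blocksize
  const int64_t T = 2; // samples per thread per chunk
  const int64_t n_samples = 100; // assumed, for illustration
  const int64_t batch_size = 8; // assumed, for illustration

  const int64_t chunks_per_batch = 1 + (n_samples - 1) / (T * threads); // 2
  const int64_t n_chunks = chunks_per_batch * batch_size; // 16
  const int64_t max_blocks = 1024;
  const int64_t blocks = n_chunks < max_blocks ? n_chunks : max_blocks; // 16

  // Each block b handles chunks b, b + blocks, b + 2*blocks, ... exactly as
  // the kernel's grid-stride loop does.
  for (int64_t i_chunk = 0; i_chunk < n_chunks; ++i_chunk) {
    const int64_t i_batch_element = i_chunk / chunks_per_batch;
    const int64_t sample_start = (i_chunk % chunks_per_batch) * (T * threads);
    const int64_t sample_end = std::min(sample_start + T * threads, n_samples);
    std::printf(
        "chunk %2lld -> batch element %lld, samples [%lld, %lld)\n",
        (long long)i_chunk,
        (long long)i_batch_element,
        (long long)sample_start,
        (long long)sample_end);
  }
}

With these numbers each batch element needs two chunks of 64 samples (the second clamped to 100), so 16 chunks are launched across 16 blocks and each block's grid-stride loop handles exactly one chunk.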

+74

@@ -0,0 +1,74 @@
/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once
#include <torch/extension.h>
#include <cstdio>
#include <tuple>
#include "utils/pytorch3d_cutils.h"

// ****************************************************************************
// *                                SamplePdf                                 *
// ****************************************************************************

// Samples a probability density function defined by bin edges `bins` and
// the non-negative per-bin probabilities `weights`.

// Args:
//   bins: FloatTensor of shape `(batch_size, n_bins+1)` denoting the edges
//     of the sampling bins.

//   weights: FloatTensor of shape `(batch_size, n_bins)` containing
//     non-negative numbers representing the probability of sampling the
//     corresponding bin.

//   uniforms: The quantiles to draw, FloatTensor of shape
//     `(batch_size, n_samples)`.

//   outputs: On call, this contains the quantiles to draw. It is overwritten
//     with the drawn samples. FloatTensor of shape
//     `(batch_size, n_samples)`, where `n_samples` are drawn from each
//     distribution.

//   eps: A constant preventing division by zero in case empty bins are
//     present.

// Not differentiable

#ifdef WITH_CUDA
void SamplePdfCuda(
    const torch::Tensor& bins,
    const torch::Tensor& weights,
    const torch::Tensor& outputs,
    float eps);
#endif

void SamplePdfCpu(
    const torch::Tensor& bins,
    const torch::Tensor& weights,
    const torch::Tensor& outputs,
    float eps);

inline void SamplePdf(
    const torch::Tensor& bins,
    const torch::Tensor& weights,
    const torch::Tensor& outputs,
    float eps) {
  if (bins.is_cuda()) {
#ifdef WITH_CUDA
    CHECK_CUDA(weights);
    CHECK_CONTIGUOUS_CUDA(outputs);
    SamplePdfCuda(bins, weights, outputs, eps);
    return;
#else
    AT_ERROR("Not compiled with GPU support.");
#endif
  }
  CHECK_CONTIGUOUS(outputs);
  SamplePdfCpu(bins, weights, outputs, eps);
}
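
For reference, a hedged sketch of how the SamplePdf dispatcher declared above might be called from C++ code compiled together with the extension. The shapes, eps value, and the sample_pdf_example function are assumptions for illustration; only the argument convention (quantiles passed in through outputs and overwritten in place) comes from the docstring above.

#include <torch/extension.h>
#include "sample_pdf/sample_pdf.h" // include path as registered in ext.cpp

// Sketch only: fill `outputs` with uniform quantiles and let SamplePdf
// overwrite them in place with the drawn samples.
void sample_pdf_example() {
  const int64_t batch_size = 2, n_bins = 64, n_samples = 128;
  auto opts = torch::dtype(torch::kFloat32);
  torch::Tensor bins =
      torch::linspace(0.0, 1.0, n_bins + 1, opts).repeat({batch_size, 1});
  torch::Tensor weights = torch::rand({batch_size, n_bins}, opts);
  torch::Tensor outputs =
      torch::rand({batch_size, n_samples}, opts); // quantiles in, samples out
  SamplePdf(bins, weights, outputs, /*eps=*/1e-5f);
}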

@@ -0,0 +1,141 @@
/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <torch/extension.h>
#include <algorithm>
#include <thread>
#include <vector>

// If the number of bins is the typical 64, it is
// quicker to use binary search than linear scan.
// With more bins, it is more important.
// There is no equivalent CUDA implementation yet.
#define USE_BINARY_SEARCH

namespace {
// This worker function does the job of SamplePdf but only on
// batch elements in [start_batch, end_batch).
void SamplePdfCpu_worker(
    const torch::Tensor& bins,
    const torch::Tensor& weights,
    const torch::Tensor& outputs,
    float eps,
    int64_t start_batch,
    int64_t end_batch) {
  const int64_t n_bins = weights.size(1);
  const int64_t n_samples = outputs.size(1);

  auto bins_a = bins.accessor<float, 2>();
  auto weights_a = weights.accessor<float, 2>();
  float* __restrict__ output_p =
      outputs.data_ptr<float>() + start_batch * n_samples;

#ifdef USE_BINARY_SEARCH
  std::vector<float> partial_sums(n_bins);
#endif

  for (int64_t i_batch_elt = start_batch; i_batch_elt < end_batch;
       ++i_batch_elt) {
    auto bin_a = bins_a[i_batch_elt];
    auto weight_a = weights_a[i_batch_elt];

    // Here we do the work which has to be done once per batch element:
    // (1) finding the total weight, and (2) if using binary search,
    // precomputing the partial sums of the weights.

    float total_weight = 0;
    for (int64_t i_bin = 0; i_bin < n_bins; ++i_bin) {
      total_weight += weight_a[i_bin];
#ifdef USE_BINARY_SEARCH
      partial_sums[i_bin] = total_weight;
#endif
    }
    total_weight += eps;

    for (int64_t i_sample = 0; i_sample < n_samples; ++i_sample) {
      // Here we are taking a single random quantile (which is stored
      // in *output_p) and using it to make a single sample, which we
      // write back to the same location. First we find which bin
      // the quantile lives in, either by binary search in the
      // precomputed partial sums, or by scanning through the weights.

      float uniform = total_weight * *output_p;
#ifdef USE_BINARY_SEARCH
      int64_t i_bin = std::lower_bound(
                          partial_sums.begin(), --partial_sums.end(), uniform) -
          partial_sums.begin();
      if (i_bin > 0) {
        uniform -= partial_sums[i_bin - 1];
      }
#else
      int64_t i_bin = 0;
      while (i_bin + 1 < n_bins && uniform > weight_a[i_bin]) {
        uniform -= weight_a[i_bin];
        ++i_bin;
      }
#endif

      // Now i_bin identifies the bin the quantile lives in; we use
      // straight-line interpolation to find the position of the
      // quantile within the bin, and write it to *output_p.

      float bin_start = bin_a[i_bin];
      float bin_end = bin_a[i_bin + 1];
      float bin_weight = weight_a[i_bin];
      float output_value = bin_start;
      if (uniform > bin_weight) {
        output_value = bin_end;
      } else if (bin_weight > eps) {
        output_value += (uniform / bin_weight) * (bin_end - bin_start);
      }
      *output_p = output_value;
      ++output_p;
    }
  }
}

} // anonymous namespace

void SamplePdfCpu(
    const torch::Tensor& bins,
    const torch::Tensor& weights,
    const torch::Tensor& outputs,
    float eps) {
  const int64_t batch_size = bins.size(0);
  const int64_t max_threads = std::min(4, at::get_num_threads());
  const int64_t n_threads = std::min(max_threads, batch_size);
  if (batch_size == 0) {
    return;
  }

  // SamplePdfCpu_worker does the work of this function. We send separate
  // ranges of batch elements to that function in n_threads-1 separate threads.

  std::vector<std::thread> threads;
  threads.reserve(n_threads - 1);
  const int64_t batch_elements_per_thread = 1 + (batch_size - 1) / n_threads;
  int64_t start_batch = 0;
  for (int iThread = 0; iThread < n_threads - 1; ++iThread) {
    threads.emplace_back(
        SamplePdfCpu_worker,
        bins,
        weights,
        outputs,
        eps,
        start_batch,
        start_batch + batch_elements_per_thread);
    start_batch += batch_elements_per_thread;
  }

  // The remaining batch elements are calculated in this thread. If n_threads
  // is 1 then all the work happens in this line.
  SamplePdfCpu_worker(bins, weights, outputs, eps, start_batch, batch_size);
  for (auto&& thread : threads) {
    thread.join();
  }
}
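
The USE_BINARY_SEARCH path above relies on std::lower_bound over running partial sums of the weights. Here is a self-contained sketch of that bin lookup; the find_bin helper and its names are hypothetical, not part of the commit.

#include <algorithm>
#include <cstdint>
#include <vector>

// Cumulative sums of the weights act as an unnormalized CDF, and
// std::lower_bound finds the first bin whose cumulative weight reaches the
// scaled quantile. The last partial sum is excluded so an out-of-range
// quantile clamps to the final bin, matching the worker above.
int64_t find_bin(const std::vector<float>& weights, float uniform) {
  std::vector<float> partial_sums(weights.size());
  float total = 0.f;
  for (size_t i = 0; i < weights.size(); ++i) {
    total += weights[i];
    partial_sums[i] = total;
  }
  auto it = std::lower_bound(
      partial_sums.begin(), partial_sums.end() - 1, uniform);
  return it - partial_sums.begin();
}

For example, with weights {0.1, 0.5, 0.4} and a scaled quantile of 0.55, the partial sums are {0.1, 0.6, 1.0}; lower_bound over the first two entries returns bin 1, the same bin the CUDA kernel's linear scan would find.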
