From 936b141c89c3e8cb9f34ca352ce09f1552aa24a2 Mon Sep 17 00:00:00 2001 From: Jonathon Luiten Date: Wed, 12 Jul 2023 16:52:19 -0700 Subject: [PATCH] Added forward pass for depth rendering --- cuda_rasterizer/forward.cu | 19 ++++++++++++++++--- cuda_rasterizer/forward.h | 4 +++- cuda_rasterizer/rasterizer.h | 1 + cuda_rasterizer/rasterizer_impl.cu | 5 ++++- diff_gaussian_rasterization/__init__.py | 7 ++++--- rasterize_points.cu | 6 ++++-- rasterize_points.h | 2 +- 7 files changed, 33 insertions(+), 11 deletions(-) diff --git a/cuda_rasterizer/forward.cu b/cuda_rasterizer/forward.cu index c419a328..9109ae72 100644 --- a/cuda_rasterizer/forward.cu +++ b/cuda_rasterizer/forward.cu @@ -270,7 +270,9 @@ renderCUDA( float* __restrict__ final_T, uint32_t* __restrict__ n_contrib, const float* __restrict__ bg_color, - float* __restrict__ out_color) + float* __restrict__ out_color, + const float* __restrict__ depth, + float* __restrict__ out_depth) { // Identify current tile and associated min/max pixel range. auto block = cg::this_thread_block(); @@ -295,12 +297,14 @@ renderCUDA( __shared__ int collected_id[BLOCK_SIZE]; __shared__ float2 collected_xy[BLOCK_SIZE]; __shared__ float4 collected_conic_opacity[BLOCK_SIZE]; + __shared__ float collected_depth[BLOCK_SIZE]; // Initialize helper variables float T = 1.0f; uint32_t contributor = 0; uint32_t last_contributor = 0; float C[CHANNELS] = { 0 }; + float D = 0.0f; // Iterate over batches until all done or range is complete for (int i = 0; i < rounds; i++, toDo -= BLOCK_SIZE) @@ -318,6 +322,7 @@ renderCUDA( collected_id[block.thread_rank()] = coll_id; collected_xy[block.thread_rank()] = points_xy_image[coll_id]; collected_conic_opacity[block.thread_rank()] = conic_opacity[coll_id]; + collected_depth[block.thread_rank()] = depth[coll_id]; } block.sync(); @@ -354,6 +359,9 @@ renderCUDA( for (int ch = 0; ch < CHANNELS; ch++) C[ch] += features[collected_id[j] * CHANNELS + ch] * alpha * T; + float dep = collected_depth[j]; + D += dep * alpha * T; + T = test_T; // Keep track of last range entry to update this @@ -370,6 +378,7 @@ renderCUDA( n_contrib[pix_id] = last_contributor; for (int ch = 0; ch < CHANNELS; ch++) out_color[ch * H * W + pix_id] = C[ch] + T * bg_color[ch]; + out_depth[pix_id] = D; } } @@ -384,7 +393,9 @@ void FORWARD::render( float* final_T, uint32_t* n_contrib, const float* bg_color, - float* out_color) + float* out_color, + const float* depth, + float* out_depth) { renderCUDA << > > ( ranges, @@ -396,7 +407,9 @@ void FORWARD::render( final_T, n_contrib, bg_color, - out_color); + out_color, + depth, + out_depth); } void FORWARD::preprocess(int P, int D, int M, diff --git a/cuda_rasterizer/forward.h b/cuda_rasterizer/forward.h index 3c11cb91..f2945869 100644 --- a/cuda_rasterizer/forward.h +++ b/cuda_rasterizer/forward.h @@ -59,7 +59,9 @@ namespace FORWARD float* final_T, uint32_t* n_contrib, const float* bg_color, - float* out_color); + float* out_color, + const float* depth, + float* out_depth); } diff --git a/cuda_rasterizer/rasterizer.h b/cuda_rasterizer/rasterizer.h index 2cde606d..6ede4eb8 100644 --- a/cuda_rasterizer/rasterizer.h +++ b/cuda_rasterizer/rasterizer.h @@ -49,6 +49,7 @@ namespace CudaRasterizer const float tan_fovx, float tan_fovy, const bool prefiltered, float* out_color, + float* out_depth, int* radii = nullptr); static void backward( diff --git a/cuda_rasterizer/rasterizer_impl.cu b/cuda_rasterizer/rasterizer_impl.cu index d7b9d6ab..d79f0bb8 100644 --- a/cuda_rasterizer/rasterizer_impl.cu +++ b/cuda_rasterizer/rasterizer_impl.cu @@ -216,6 +216,7 @@ int CudaRasterizer::Rasterizer::forward( const float tan_fovx, float tan_fovy, const bool prefiltered, float* out_color, + float* out_depth, int* radii) { const float focal_y = height / (2.0f * tan_fovy); @@ -330,7 +331,9 @@ int CudaRasterizer::Rasterizer::forward( imgState.accum_alpha, imgState.n_contrib, background, - out_color); + out_color, + geomState.depths, + out_depth); return num_rendered; } diff --git a/diff_gaussian_rasterization/__init__.py b/diff_gaussian_rasterization/__init__.py index 4b072f7a..485c8cc6 100644 --- a/diff_gaussian_rasterization/__init__.py +++ b/diff_gaussian_rasterization/__init__.py @@ -75,16 +75,17 @@ def forward( ) # Invoke C++/CUDA rasterizer - num_rendered, color, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args) + # num_rendered, color, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args) + num_rendered, color, radii, geomBuffer, binningBuffer, imgBuffer, depth = _C.rasterize_gaussians(*args) # Keep relevant tensors for backward ctx.raster_settings = raster_settings ctx.num_rendered = num_rendered ctx.save_for_backward(colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh, geomBuffer, binningBuffer, imgBuffer) - return color, radii + return color, radii, depth @staticmethod - def backward(ctx, grad_out_color, _): + def backward(ctx, grad_out_color, _, depth): # Restore necessary values from context num_rendered = ctx.num_rendered diff --git a/rasterize_points.cu b/rasterize_points.cu index 90e10be2..480e4135 100644 --- a/rasterize_points.cu +++ b/rasterize_points.cu @@ -32,7 +32,7 @@ std::function resizeFunctional(torch::Tensor& t) { return lambda; } -std::tuple +std::tuple RasterizeGaussiansCUDA( const torch::Tensor& background, const torch::Tensor& means3D, @@ -66,6 +66,7 @@ RasterizeGaussiansCUDA( torch::Tensor out_color = torch::full({NUM_CHANNELS, H, W}, 0.0, float_opts); torch::Tensor radii = torch::full({P}, 0, means3D.options().dtype(torch::kInt32)); + torch::Tensor out_depth = torch::full({1, H, W}, 0.0, float_opts); torch::Device device(torch::kCUDA); torch::TensorOptions options(torch::kByte); @@ -107,9 +108,10 @@ RasterizeGaussiansCUDA( tan_fovy, prefiltered, out_color.contiguous().data(), + out_depth.contiguous().data(), radii.contiguous().data()); } - return std::make_tuple(rendered, out_color, radii, geomBuffer, binningBuffer, imgBuffer); + return std::make_tuple(rendered, out_color, radii, geomBuffer, binningBuffer, imgBuffer, out_depth); } std::tuple diff --git a/rasterize_points.h b/rasterize_points.h index 9be145d6..8f36814d 100644 --- a/rasterize_points.h +++ b/rasterize_points.h @@ -15,7 +15,7 @@ #include #include -std::tuple +std::tuple RasterizeGaussiansCUDA( const torch::Tensor& background, const torch::Tensor& means3D,