Skip to content

Added forward pass for depth rendering #2

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions cuda_rasterizer/forward.cu
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,9 @@ renderCUDA(
float* __restrict__ final_T,
uint32_t* __restrict__ n_contrib,
const float* __restrict__ bg_color,
float* __restrict__ out_color)
float* __restrict__ out_color,
const float* __restrict__ depth,
float* __restrict__ out_depth)
{
// Identify current tile and associated min/max pixel range.
auto block = cg::this_thread_block();
Expand All @@ -295,12 +297,14 @@ renderCUDA(
__shared__ int collected_id[BLOCK_SIZE];
__shared__ float2 collected_xy[BLOCK_SIZE];
__shared__ float4 collected_conic_opacity[BLOCK_SIZE];
__shared__ float collected_depth[BLOCK_SIZE];

// Initialize helper variables
float T = 1.0f;
uint32_t contributor = 0;
uint32_t last_contributor = 0;
float C[CHANNELS] = { 0 };
float D = 0.0f;

// Iterate over batches until all done or range is complete
for (int i = 0; i < rounds; i++, toDo -= BLOCK_SIZE)
Expand All @@ -318,6 +322,7 @@ renderCUDA(
collected_id[block.thread_rank()] = coll_id;
collected_xy[block.thread_rank()] = points_xy_image[coll_id];
collected_conic_opacity[block.thread_rank()] = conic_opacity[coll_id];
collected_depth[block.thread_rank()] = depth[coll_id];
}
block.sync();

Expand Down Expand Up @@ -354,6 +359,9 @@ renderCUDA(
for (int ch = 0; ch < CHANNELS; ch++)
C[ch] += features[collected_id[j] * CHANNELS + ch] * alpha * T;

float dep = collected_depth[j];
D += dep * alpha * T;

T = test_T;

// Keep track of last range entry to update this
Expand All @@ -370,6 +378,7 @@ renderCUDA(
n_contrib[pix_id] = last_contributor;
for (int ch = 0; ch < CHANNELS; ch++)
out_color[ch * H * W + pix_id] = C[ch] + T * bg_color[ch];
out_depth[pix_id] = D;
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should it be D / (1 - T) here?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jb-ye I think so, too.

}
}

Expand All @@ -384,7 +393,9 @@ void FORWARD::render(
float* final_T,
uint32_t* n_contrib,
const float* bg_color,
float* out_color)
float* out_color,
const float* depth,
float* out_depth)
{
renderCUDA<NUM_CHANNELS> << <grid, block >> > (
ranges,
Expand All @@ -396,7 +407,9 @@ void FORWARD::render(
final_T,
n_contrib,
bg_color,
out_color);
out_color,
depth,
out_depth);
}

void FORWARD::preprocess(int P, int D, int M,
Expand Down
4 changes: 3 additions & 1 deletion cuda_rasterizer/forward.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ namespace FORWARD
float* final_T,
uint32_t* n_contrib,
const float* bg_color,
float* out_color);
float* out_color,
const float* depth,
float* out_depth);
}


Expand Down
1 change: 1 addition & 0 deletions cuda_rasterizer/rasterizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ namespace CudaRasterizer
const float tan_fovx, float tan_fovy,
const bool prefiltered,
float* out_color,
float* out_depth,
int* radii = nullptr);

static void backward(
Expand Down
5 changes: 4 additions & 1 deletion cuda_rasterizer/rasterizer_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ int CudaRasterizer::Rasterizer::forward(
const float tan_fovx, float tan_fovy,
const bool prefiltered,
float* out_color,
float* out_depth,
int* radii)
{
const float focal_y = height / (2.0f * tan_fovy);
Expand Down Expand Up @@ -330,7 +331,9 @@ int CudaRasterizer::Rasterizer::forward(
imgState.accum_alpha,
imgState.n_contrib,
background,
out_color);
out_color,
geomState.depths,
out_depth);

return num_rendered;
}
Expand Down
7 changes: 4 additions & 3 deletions diff_gaussian_rasterization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,16 +75,17 @@ def forward(
)

# Invoke C++/CUDA rasterizer
num_rendered, color, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args)
# num_rendered, color, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args)
num_rendered, color, radii, geomBuffer, binningBuffer, imgBuffer, depth = _C.rasterize_gaussians(*args)

# Keep relevant tensors for backward
ctx.raster_settings = raster_settings
ctx.num_rendered = num_rendered
ctx.save_for_backward(colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh, geomBuffer, binningBuffer, imgBuffer)
return color, radii
return color, radii, depth

@staticmethod
def backward(ctx, grad_out_color, _):
def backward(ctx, grad_out_color, _, depth):

# Restore necessary values from context
num_rendered = ctx.num_rendered
Expand Down
6 changes: 4 additions & 2 deletions rasterize_points.cu
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ std::function<char*(size_t N)> resizeFunctional(torch::Tensor& t) {
return lambda;
}

std::tuple<int, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
std::tuple<int, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeGaussiansCUDA(
const torch::Tensor& background,
const torch::Tensor& means3D,
Expand Down Expand Up @@ -66,6 +66,7 @@ RasterizeGaussiansCUDA(

torch::Tensor out_color = torch::full({NUM_CHANNELS, H, W}, 0.0, float_opts);
torch::Tensor radii = torch::full({P}, 0, means3D.options().dtype(torch::kInt32));
torch::Tensor out_depth = torch::full({1, H, W}, 0.0, float_opts);

torch::Device device(torch::kCUDA);
torch::TensorOptions options(torch::kByte);
Expand Down Expand Up @@ -107,9 +108,10 @@ RasterizeGaussiansCUDA(
tan_fovy,
prefiltered,
out_color.contiguous().data<float>(),
out_depth.contiguous().data<float>(),
radii.contiguous().data<int>());
}
return std::make_tuple(rendered, out_color, radii, geomBuffer, binningBuffer, imgBuffer);
return std::make_tuple(rendered, out_color, radii, geomBuffer, binningBuffer, imgBuffer, out_depth);
}

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
Expand Down
2 changes: 1 addition & 1 deletion rasterize_points.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#include <tuple>
#include <string>

std::tuple<int, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
std::tuple<int, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeGaussiansCUDA(
const torch::Tensor& background,
const torch::Tensor& means3D,
Expand Down