Skip to content

Commit 62dbf37

Browse files
jcjohnsonfacebook-github-bot
authored andcommitted
Move coarse rasterization to new file
Summary: In preparation for sharing coarse rasterization between point clouds and meshes, move the functions to a new file. No code changes. Reviewed By: bottler Differential Revision: D30367812 fbshipit-source-id: 9e73835a26c4ac91f5c9f61ff682bc8218e36c6a
1 parent f2c44e3 commit 62dbf37

File tree

8 files changed

+534
-480
lines changed

8 files changed

+534
-480
lines changed

pytorch3d/csrc/rasterize_coarse/rasterize_coarse.cu

+481
Large diffs are not rendered by default.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
/*
2+
* Copyright (c) Facebook, Inc. and its affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <torch/extension.h>
12+
#include <tuple>
13+
14+
// Arguments are the same as RasterizeMeshesCoarse from
15+
// rasterize_meshes/rasterize_meshes.h
16+
#ifdef WITH_CUDA
17+
torch::Tensor RasterizeMeshesCoarseCuda(
18+
const torch::Tensor& face_verts,
19+
const torch::Tensor& mesh_to_face_first_idx,
20+
const torch::Tensor& num_faces_per_mesh,
21+
const std::tuple<int, int> image_size,
22+
const float blur_radius,
23+
const int bin_size,
24+
const int max_faces_per_bin);
25+
#endif
26+
27+
// Arguments are the same as RasterizePointsCoarse from
28+
// rasterize_points/rasterize_points.h
29+
#ifdef WITH_CUDA
30+
torch::Tensor RasterizePointsCoarseCuda(
31+
const torch::Tensor& points,
32+
const torch::Tensor& cloud_to_packed_first_idx,
33+
const torch::Tensor& num_points_per_cloud,
34+
const std::tuple<int, int> image_size,
35+
const torch::Tensor& radius,
36+
const int bin_size,
37+
const int max_points_per_bin);
38+
#endif

pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu

-233
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
#include <thrust/tuple.h>
1515
#include <cstdio>
1616
#include <tuple>
17-
#include "rasterize_points/bitmask.cuh"
1817
#include "rasterize_points/rasterization_utils.cuh"
1918
#include "utils/float_math.cuh"
2019
#include "utils/geometry_utils.cuh"
@@ -32,14 +31,6 @@ __device__ bool operator<(const Pixel& a, const Pixel& b) {
3231
return a.z < b.z;
3332
}
3433

35-
__device__ float FloatMin3(const float p1, const float p2, const float p3) {
36-
return fminf(p1, fminf(p2, p3));
37-
}
38-
39-
__device__ float FloatMax3(const float p1, const float p2, const float p3) {
40-
return fmaxf(p1, fmaxf(p2, p3));
41-
}
42-
4334
// Get the xyz coordinates of the three vertices for the face given by the
4435
// index face_idx into face_verts.
4536
__device__ thrust::tuple<float3, float3, float3> GetSingleFaceVerts(
@@ -630,230 +621,6 @@ at::Tensor RasterizeMeshesBackwardCuda(
630621
return grad_face_verts;
631622
}
632623

633-
// ****************************************************************************
634-
// * COARSE RASTERIZATION *
635-
// ****************************************************************************
636-
637-
__global__ void RasterizeMeshesCoarseCudaKernel(
638-
const float* face_verts,
639-
const int64_t* mesh_to_face_first_idx,
640-
const int64_t* num_faces_per_mesh,
641-
const float blur_radius,
642-
const int N,
643-
const int F,
644-
const int H,
645-
const int W,
646-
const int bin_size,
647-
const int chunk_size,
648-
const int max_faces_per_bin,
649-
int* faces_per_bin,
650-
int* bin_faces) {
651-
extern __shared__ char sbuf[];
652-
const int M = max_faces_per_bin;
653-
// Integer divide round up
654-
const int num_bins_x = 1 + (W - 1) / bin_size;
655-
const int num_bins_y = 1 + (H - 1) / bin_size;
656-
657-
// NDC range depends on the ratio of W/H
658-
// The shorter side from (H, W) is given an NDC range of 2.0 and
659-
// the other side is scaled by the ratio of H:W.
660-
const float NDC_x_half_range = NonSquareNdcRange(W, H) / 2.0f;
661-
const float NDC_y_half_range = NonSquareNdcRange(H, W) / 2.0f;
662-
663-
// Size of half a pixel in NDC units is the NDC half range
664-
// divided by the corresponding image dimension
665-
const float half_pix_x = NDC_x_half_range / W;
666-
const float half_pix_y = NDC_y_half_range / H;
667-
668-
// This is a boolean array of shape (num_bins_y, num_bins_x, chunk_size)
669-
// stored in shared memory that will track whether each point in the chunk
670-
// falls into each bin of the image.
671-
BitMask binmask((unsigned int*)sbuf, num_bins_y, num_bins_x, chunk_size);
672-
673-
// Have each block handle a chunk of faces
674-
const int chunks_per_batch = 1 + (F - 1) / chunk_size;
675-
const int num_chunks = N * chunks_per_batch;
676-
677-
for (int chunk = blockIdx.x; chunk < num_chunks; chunk += gridDim.x) {
678-
const int batch_idx = chunk / chunks_per_batch; // batch index
679-
const int chunk_idx = chunk % chunks_per_batch;
680-
const int face_start_idx = chunk_idx * chunk_size;
681-
682-
binmask.block_clear();
683-
const int64_t mesh_face_start_idx = mesh_to_face_first_idx[batch_idx];
684-
const int64_t mesh_face_stop_idx =
685-
mesh_face_start_idx + num_faces_per_mesh[batch_idx];
686-
687-
// Have each thread handle a different face within the chunk
688-
for (int f = threadIdx.x; f < chunk_size; f += blockDim.x) {
689-
const int f_idx = face_start_idx + f;
690-
691-
// Check if face index corresponds to the mesh in the batch given by
692-
// batch_idx
693-
if (f_idx >= mesh_face_stop_idx || f_idx < mesh_face_start_idx) {
694-
continue;
695-
}
696-
697-
// Get xyz coordinates of the three face vertices.
698-
const auto v012 = GetSingleFaceVerts(face_verts, f_idx);
699-
const float3 v0 = thrust::get<0>(v012);
700-
const float3 v1 = thrust::get<1>(v012);
701-
const float3 v2 = thrust::get<2>(v012);
702-
703-
// Compute screen-space bbox for the triangle expanded by blur.
704-
float xmin = FloatMin3(v0.x, v1.x, v2.x) - sqrt(blur_radius);
705-
float ymin = FloatMin3(v0.y, v1.y, v2.y) - sqrt(blur_radius);
706-
float xmax = FloatMax3(v0.x, v1.x, v2.x) + sqrt(blur_radius);
707-
float ymax = FloatMax3(v0.y, v1.y, v2.y) + sqrt(blur_radius);
708-
float zmin = FloatMin3(v0.z, v1.z, v2.z);
709-
710-
// Faces with at least one vertex behind the camera won't render
711-
// correctly and should be removed or clipped before calling the
712-
// rasterizer
713-
if (zmin < kEpsilon) {
714-
continue;
715-
}
716-
717-
// Brute-force search over all bins; TODO(T54294966) something smarter.
718-
for (int by = 0; by < num_bins_y; ++by) {
719-
// Y coordinate of the top and bottom of the bin.
720-
// PixToNdc gives the location of the center of each pixel, so we
721-
// need to add/subtract a half pixel to get the true extent of the bin.
722-
// Reverse ordering of Y axis so that +Y is upwards in the image.
723-
const float bin_y_min =
724-
PixToNonSquareNdc(by * bin_size, H, W) - half_pix_y;
725-
const float bin_y_max =
726-
PixToNonSquareNdc((by + 1) * bin_size - 1, H, W) + half_pix_y;
727-
const bool y_overlap = (ymin <= bin_y_max) && (bin_y_min < ymax);
728-
729-
for (int bx = 0; bx < num_bins_x; ++bx) {
730-
// X coordinate of the left and right of the bin.
731-
// Reverse ordering of x axis so that +X is left.
732-
const float bin_x_max =
733-
PixToNonSquareNdc((bx + 1) * bin_size - 1, W, H) + half_pix_x;
734-
const float bin_x_min =
735-
PixToNonSquareNdc(bx * bin_size, W, H) - half_pix_x;
736-
737-
const bool x_overlap = (xmin <= bin_x_max) && (bin_x_min < xmax);
738-
if (y_overlap && x_overlap) {
739-
binmask.set(by, bx, f);
740-
}
741-
}
742-
}
743-
}
744-
__syncthreads();
745-
// Now we have processed every face in the current chunk. We need to
746-
// count the number of faces in each bin so we can write the indices
747-
// out to global memory. We have each thread handle a different bin.
748-
for (int byx = threadIdx.x; byx < num_bins_y * num_bins_x;
749-
byx += blockDim.x) {
750-
const int by = byx / num_bins_x;
751-
const int bx = byx % num_bins_x;
752-
const int count = binmask.count(by, bx);
753-
const int faces_per_bin_idx =
754-
batch_idx * num_bins_y * num_bins_x + by * num_bins_x + bx;
755-
756-
// This atomically increments the (global) number of faces found
757-
// in the current bin, and gets the previous value of the counter;
758-
// this effectively allocates space in the bin_faces array for the
759-
// faces in the current chunk that fall into this bin.
760-
const int start = atomicAdd(faces_per_bin + faces_per_bin_idx, count);
761-
762-
// Now loop over the binmask and write the active bits for this bin
763-
// out to bin_faces.
764-
int next_idx = batch_idx * num_bins_y * num_bins_x * M +
765-
by * num_bins_x * M + bx * M + start;
766-
for (int f = 0; f < chunk_size; ++f) {
767-
if (binmask.get(by, bx, f)) {
768-
// TODO(T54296346) find the correct method for handling errors in
769-
// CUDA. Throw an error if num_faces_per_bin > max_faces_per_bin.
770-
// Either decrease bin size or increase max_faces_per_bin
771-
bin_faces[next_idx] = face_start_idx + f;
772-
next_idx++;
773-
}
774-
}
775-
}
776-
__syncthreads();
777-
}
778-
}
779-
780-
at::Tensor RasterizeMeshesCoarseCuda(
781-
const at::Tensor& face_verts,
782-
const at::Tensor& mesh_to_face_first_idx,
783-
const at::Tensor& num_faces_per_mesh,
784-
const std::tuple<int, int> image_size,
785-
const float blur_radius,
786-
const int bin_size,
787-
const int max_faces_per_bin) {
788-
TORCH_CHECK(
789-
face_verts.ndimension() == 3 && face_verts.size(1) == 3 &&
790-
face_verts.size(2) == 3,
791-
"face_verts must have dimensions (num_faces, 3, 3)");
792-
793-
// Check inputs are on the same device
794-
at::TensorArg face_verts_t{face_verts, "face_verts", 1},
795-
mesh_to_face_first_idx_t{
796-
mesh_to_face_first_idx, "mesh_to_face_first_idx", 2},
797-
num_faces_per_mesh_t{num_faces_per_mesh, "num_faces_per_mesh", 3};
798-
at::CheckedFrom c = "RasterizeMeshesCoarseCuda";
799-
at::checkAllSameGPU(
800-
c, {face_verts_t, mesh_to_face_first_idx_t, num_faces_per_mesh_t});
801-
802-
// Set the device for the kernel launch based on the device of the input
803-
at::cuda::CUDAGuard device_guard(face_verts.device());
804-
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
805-
806-
const int H = std::get<0>(image_size);
807-
const int W = std::get<1>(image_size);
808-
809-
const int F = face_verts.size(0);
810-
const int N = num_faces_per_mesh.size(0);
811-
const int M = max_faces_per_bin;
812-
813-
// Integer divide round up.
814-
const int num_bins_y = 1 + (H - 1) / bin_size;
815-
const int num_bins_x = 1 + (W - 1) / bin_size;
816-
817-
if (num_bins_y >= kMaxItemsPerBin || num_bins_x >= kMaxItemsPerBin) {
818-
std::stringstream ss;
819-
ss << "In Coarse Rasterizer got num_bins_y: " << num_bins_y
820-
<< ", num_bins_x: " << num_bins_x << ", "
821-
<< "; that's too many!";
822-
AT_ERROR(ss.str());
823-
}
824-
auto opts = num_faces_per_mesh.options().dtype(at::kInt);
825-
at::Tensor faces_per_bin = at::zeros({N, num_bins_y, num_bins_x}, opts);
826-
at::Tensor bin_faces = at::full({N, num_bins_y, num_bins_x, M}, -1, opts);
827-
828-
if (bin_faces.numel() == 0) {
829-
AT_CUDA_CHECK(cudaGetLastError());
830-
return bin_faces;
831-
}
832-
833-
const int chunk_size = 512;
834-
const size_t shared_size = num_bins_y * num_bins_x * chunk_size / 8;
835-
const size_t blocks = 64;
836-
const size_t threads = 512;
837-
838-
RasterizeMeshesCoarseCudaKernel<<<blocks, threads, shared_size, stream>>>(
839-
face_verts.contiguous().data_ptr<float>(),
840-
mesh_to_face_first_idx.contiguous().data_ptr<int64_t>(),
841-
num_faces_per_mesh.contiguous().data_ptr<int64_t>(),
842-
blur_radius,
843-
N,
844-
F,
845-
H,
846-
W,
847-
bin_size,
848-
chunk_size,
849-
M,
850-
faces_per_bin.data_ptr<int32_t>(),
851-
bin_faces.data_ptr<int32_t>());
852-
853-
AT_CUDA_CHECK(cudaGetLastError());
854-
return bin_faces;
855-
}
856-
857624
// ****************************************************************************
858625
// * FINE RASTERIZATION *
859626
// ****************************************************************************

pytorch3d/csrc/rasterize_meshes/rasterize_meshes.h

+4-11
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <torch/extension.h>
1111
#include <cstdio>
1212
#include <tuple>
13+
#include "rasterize_coarse/rasterize_coarse.h"
1314
#include "utils/pytorch3d_cutils.h"
1415

1516
// ****************************************************************************
@@ -236,6 +237,8 @@ torch::Tensor RasterizeMeshesBackward(
236237
// * COARSE RASTERIZATION *
237238
// ****************************************************************************
238239

240+
// RasterizeMeshesCoarseCuda in rasterize_coarse/rasterize_coarse.h
241+
239242
torch::Tensor RasterizeMeshesCoarseCpu(
240243
const torch::Tensor& face_verts,
241244
const at::Tensor& mesh_to_face_first_idx,
@@ -245,16 +248,6 @@ torch::Tensor RasterizeMeshesCoarseCpu(
245248
const int bin_size,
246249
const int max_faces_per_bin);
247250

248-
#ifdef WITH_CUDA
249-
torch::Tensor RasterizeMeshesCoarseCuda(
250-
const torch::Tensor& face_verts,
251-
const torch::Tensor& mesh_to_face_first_idx,
252-
const torch::Tensor& num_faces_per_mesh,
253-
const std::tuple<int, int> image_size,
254-
const float blur_radius,
255-
const int bin_size,
256-
const int max_faces_per_bin);
257-
#endif
258251
// Args:
259252
// face_verts: Tensor of shape (F, 3, 3) giving (packed) vertex positions for
260253
// faces in all the meshes in the batch. Concretely,
@@ -499,7 +492,7 @@ RasterizeMeshes(
499492
const bool cull_backfaces) {
500493
if (bin_size > 0 && max_faces_per_bin > 0) {
501494
// Use coarse-to-fine rasterization
502-
auto bin_faces = RasterizeMeshesCoarse(
495+
at::Tensor bin_faces = RasterizeMeshesCoarse(
503496
face_verts,
504497
mesh_to_face_first_idx,
505498
num_faces_per_mesh,

0 commit comments

Comments
 (0)