Skip to content

Commit 4bf3059

Browse files
nikhilaravifacebook-github-bot
authored andcommitted
back face culling in rasterization
Summary: Added backface culling as an option to the `raster_settings`. This is needed for the full forward rendering of shapenet meshes with texture (some meshes contain multiple overlapping segments which have different textures). For a triangle (v0, v1, v2) define the vectors A = (v1 - v0) and B = (v2 − v0) and use this to calculate the area of the triangle as: ``` area = 0.5 * A x B area = 0.5 * ((x1 − x0)(y2 − y0) − (x2 − x0)(y1 − y0)) ``` The area will be positive if (v0, v1, v2) are oriented counterclockwise (a front face), and negative if (v0, v1, v2) are oriented clockwise (a back face). We can reuse the `edge_function` as it already calculates the triangle area. Reviewed By: jcjohnson Differential Revision: D20960115 fbshipit-source-id: 2d8a4b9ccfb653df18e79aed8d05c7ec0f057ab1
1 parent 3c6f922 commit 4bf3059

File tree

7 files changed

+187
-30
lines changed

7 files changed

+187
-30
lines changed

pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu

+21-8
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,8 @@ __device__ void CheckPixelInsideFace(
111111
const float blur_radius,
112112
const float2 pxy, // Coordinates of the pixel
113113
const int K,
114-
const bool perspective_correct) {
114+
const bool perspective_correct,
115+
const bool cull_backfaces) {
115116
const auto v012 = GetSingleFaceVerts(face_verts, face_idx);
116117
const float3 v0 = thrust::get<0>(v012);
117118
const float3 v1 = thrust::get<1>(v012);
@@ -124,16 +125,20 @@ __device__ void CheckPixelInsideFace(
124125

125126
// Perform checks and skip if:
126127
// 1. the face is behind the camera
127-
// 2. the face has very small face area
128-
// 3. the pixel is outside the face bbox
128+
// 2. the face is facing away from the camera
129+
// 3. the face has very small face area
130+
// 4. the pixel is outside the face bbox
129131
const float zmax = FloatMax3(v0.z, v1.z, v2.z);
130132
const bool outside_bbox = CheckPointOutsideBoundingBox(
131133
v0, v1, v2, sqrt(blur_radius), pxy); // use sqrt of blur for bbox
132134
const float face_area = EdgeFunctionForward(v0xy, v1xy, v2xy);
135+
// Check if the face is visible to the camera.
136+
const bool back_face = face_area < 0.0;
133137
const bool zero_face_area =
134138
(face_area <= kEpsilon && face_area >= -1.0f * kEpsilon);
135139

136-
if (zmax < 0 || outside_bbox || zero_face_area) {
140+
if (zmax < 0 || cull_backfaces && back_face || outside_bbox ||
141+
zero_face_area) {
137142
return;
138143
}
139144

@@ -191,6 +196,7 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
191196
const int64_t* num_faces_per_mesh,
192197
const float blur_radius,
193198
const bool perspective_correct,
199+
const bool cull_backfaces,
194200
const int N,
195201
const int H,
196202
const int W,
@@ -251,7 +257,8 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
251257
blur_radius,
252258
pxy,
253259
K,
254-
perspective_correct);
260+
perspective_correct,
261+
cull_backfaces);
255262
}
256263

257264
// TODO: make sorting an option as only top k is needed, not sorted values.
@@ -276,7 +283,8 @@ RasterizeMeshesNaiveCuda(
276283
const int image_size,
277284
const float blur_radius,
278285
const int num_closest,
279-
const bool perspective_correct) {
286+
const bool perspective_correct,
287+
const bool cull_backfaces) {
280288
if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
281289
face_verts.size(2) != 3) {
282290
AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");
@@ -314,6 +322,7 @@ RasterizeMeshesNaiveCuda(
314322
num_faces_per_mesh.contiguous().data_ptr<int64_t>(),
315323
blur_radius,
316324
perspective_correct,
325+
cull_backfaces,
317326
N,
318327
H,
319328
W,
@@ -667,6 +676,7 @@ __global__ void RasterizeMeshesFineCudaKernel(
667676
const float blur_radius,
668677
const int bin_size,
669678
const bool perspective_correct,
679+
const bool cull_backfaces,
670680
const int N,
671681
const int B,
672682
const int M,
@@ -730,7 +740,8 @@ __global__ void RasterizeMeshesFineCudaKernel(
730740
blur_radius,
731741
pxy,
732742
K,
733-
perspective_correct);
743+
perspective_correct,
744+
cull_backfaces);
734745
}
735746

736747
// Now we've looked at all the faces for this bin, so we can write
@@ -762,7 +773,8 @@ RasterizeMeshesFineCuda(
762773
const float blur_radius,
763774
const int bin_size,
764775
const int faces_per_pixel,
765-
const bool perspective_correct) {
776+
const bool perspective_correct,
777+
const bool cull_backfaces) {
766778
if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
767779
face_verts.size(2) != 3) {
768780
AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");
@@ -797,6 +809,7 @@ RasterizeMeshesFineCuda(
797809
blur_radius,
798810
bin_size,
799811
perspective_correct,
812+
cull_backfaces,
800813
N,
801814
B,
802815
M,

pytorch3d/csrc/rasterize_meshes/rasterize_meshes.h

+46-12
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ RasterizeMeshesNaiveCpu(
1717
const int image_size,
1818
const float blur_radius,
1919
const int faces_per_pixel,
20-
const bool perspective_correct);
20+
const bool perspective_correct,
21+
const bool cull_backfaces);
2122

2223
#ifdef WITH_CUDA
2324
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
@@ -28,7 +29,8 @@ RasterizeMeshesNaiveCuda(
2829
const int image_size,
2930
const float blur_radius,
3031
const int num_closest,
31-
const bool perspective_correct);
32+
const bool perspective_correct,
33+
const bool cull_backfaces);
3234
#endif
3335
// Forward pass for rasterizing a batch of meshes.
3436
//
@@ -55,6 +57,14 @@ RasterizeMeshesNaiveCuda(
5557
// coordinates for each pixel; if this is False then
5658
// this function instead returns screen-space
5759
// barycentric coordinates for each pixel.
60+
// cull_backfaces: Bool, Whether to only rasterize mesh faces which are
61+
// visible to the camera. This assumes that vertices of
62+
// front-facing triangles are ordered in an anti-clockwise
63+
// fashion, and triangles that face away from the camera are
64+
// in a clockwise order relative to the current view
65+
// direction. NOTE: This will only work if the mesh faces are
66+
// consistently defined with counter-clockwise ordering when
67+
// viewed from the outside.
5868
//
5969
// Returns:
6070
// A 4 element tuple of:
@@ -80,7 +90,8 @@ RasterizeMeshesNaive(
8090
const int image_size,
8191
const float blur_radius,
8292
const int faces_per_pixel,
83-
const bool perspective_correct) {
93+
const bool perspective_correct,
94+
const bool cull_backfaces) {
8495
// TODO: Better type checking.
8596
if (face_verts.is_cuda()) {
8697
#ifdef WITH_CUDA
@@ -91,7 +102,8 @@ RasterizeMeshesNaive(
91102
image_size,
92103
blur_radius,
93104
faces_per_pixel,
94-
perspective_correct);
105+
perspective_correct,
106+
cull_backfaces);
95107
#else
96108
AT_ERROR("Not compiled with GPU support");
97109
#endif
@@ -103,7 +115,8 @@ RasterizeMeshesNaive(
103115
image_size,
104116
blur_radius,
105117
faces_per_pixel,
106-
perspective_correct);
118+
perspective_correct,
119+
cull_backfaces);
107120
}
108121
}
109122

@@ -274,7 +287,8 @@ RasterizeMeshesFineCuda(
274287
const float blur_radius,
275288
const int bin_size,
276289
const int faces_per_pixel,
277-
const bool perspective_correct);
290+
const bool perspective_correct,
291+
const bool cull_backfaces);
278292
#endif
279293
// Args:
280294
// face_verts: Tensor of shape (F, 3, 3) giving (packed) vertex positions for
@@ -296,6 +310,14 @@ RasterizeMeshesFineCuda(
296310
// coordinates for each pixel; if this is False then
297311
// this function instead returns screen-space
298312
// barycentric coordinates for each pixel.
313+
// cull_backfaces: Bool, Whether to only rasterize mesh faces which are
314+
// visible to the camera. This assumes that vertices of
315+
// front-facing triangles are ordered in an anti-clockwise
316+
// fashion, and triangles that face away from the camera are
317+
// in a clockwise order relative to the current view
318+
// direction. NOTE: This will only work if the mesh faces are
319+
// consistently defined with counter-clockwise ordering when
320+
// viewed from the outside.
299321
//
300322
// Returns (same as rasterize_meshes):
301323
// A 4 element tuple of:
@@ -321,7 +343,8 @@ RasterizeMeshesFine(
321343
const float blur_radius,
322344
const int bin_size,
323345
const int faces_per_pixel,
324-
const bool perspective_correct) {
346+
const bool perspective_correct,
347+
const bool cull_backfaces) {
325348
if (face_verts.is_cuda()) {
326349
#ifdef WITH_CUDA
327350
return RasterizeMeshesFineCuda(
@@ -331,7 +354,8 @@ RasterizeMeshesFine(
331354
blur_radius,
332355
bin_size,
333356
faces_per_pixel,
334-
perspective_correct);
357+
perspective_correct,
358+
cull_backfaces);
335359
#else
336360
AT_ERROR("Not compiled with GPU support");
337361
#endif
@@ -372,7 +396,14 @@ RasterizeMeshesFine(
372396
// coordinates for each pixel; if this is False then
373397
// this function instead returns screen-space
374398
// barycentric coordinates for each pixel.
375-
//
399+
// cull_backfaces: Bool, Whether to only rasterize mesh faces which are
400+
// visible to the camera. This assumes that vertices of
401+
// front-facing triangles are ordered in an anti-clockwise
402+
// fashion, and triangles that face away from the camera are
403+
// in a clockwise order relative to the current view
404+
// direction. NOTE: This will only work if the mesh faces are
405+
// consistently defined with counter-clockwise ordering when
406+
// viewed from the outside.
376407
//
377408
// Returns:
378409
// A 4 element tuple of:
@@ -400,7 +431,8 @@ RasterizeMeshes(
400431
const int faces_per_pixel,
401432
const int bin_size,
402433
const int max_faces_per_bin,
403-
const bool perspective_correct) {
434+
const bool perspective_correct,
435+
const bool cull_backfaces) {
404436
if (bin_size > 0 && max_faces_per_bin > 0) {
405437
// Use coarse-to-fine rasterization
406438
auto bin_faces = RasterizeMeshesCoarse(
@@ -418,7 +450,8 @@ RasterizeMeshes(
418450
blur_radius,
419451
bin_size,
420452
faces_per_pixel,
421-
perspective_correct);
453+
perspective_correct,
454+
cull_backfaces);
422455
} else {
423456
// Use the naive per-pixel implementation
424457
return RasterizeMeshesNaive(
@@ -428,6 +461,7 @@ RasterizeMeshes(
428461
image_size,
429462
blur_radius,
430463
faces_per_pixel,
431-
perspective_correct);
464+
perspective_correct,
465+
cull_backfaces);
432466
}
433467
}

pytorch3d/csrc/rasterize_meshes/rasterize_meshes_cpu.cpp

+8-2
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,8 @@ RasterizeMeshesNaiveCpu(
107107
int image_size,
108108
const float blur_radius,
109109
const int faces_per_pixel,
110-
const bool perspective_correct) {
110+
const bool perspective_correct,
111+
const bool cull_backfaces) {
111112
if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
112113
face_verts.size(2) != 3) {
113114
AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");
@@ -184,8 +185,13 @@ RasterizeMeshesNaiveCpu(
184185
const vec2<float> v1(x1, y1);
185186
const vec2<float> v2(x2, y2);
186187

187-
// Skip faces with zero area.
188188
const float face_area = face_areas_a[f];
189+
const bool back_face = face_area < 0.0;
190+
// Check if the face is visible to the camera.
191+
if (cull_backfaces && back_face) {
192+
continue;
193+
}
194+
// Skip faces with zero area.
189195
if (face_area <= kEpsilon && face_area >= -1.0f * kEpsilon) {
190196
continue;
191197
}

pytorch3d/io/obj_io.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -140,16 +140,16 @@ def load_obj(f_obj, load_textures=True):
140140
If there are faces with more than 3 vertices
141141
they are subdivided into triangles. Polygonal faces are assummed to have
142142
vertices ordered counter-clockwise so the (right-handed) normal points
143-
into the screen e.g. a proper rectangular face would be specified like this:
143+
out of the screen e.g. a proper rectangular face would be specified like this:
144144
::
145145
0_________1
146146
| |
147147
| |
148148
3 ________2
149149
150-
The face would be split into two triangles: (0, 1, 2) and (0, 2, 3),
151-
both of which are also oriented clockwise and have normals
152-
pointing into the screen.
150+
The face would be split into two triangles: (0, 2, 1) and (0, 3, 2),
151+
both of which are also oriented counter-clockwise and have normals
152+
pointing out of the screen.
153153
154154
Args:
155155
f: A file-like object (with methods read, readline, tell, and seek),

0 commit comments

Comments
 (0)