Skip to content

Commit 6471893

Browse files
Gavin Peng, facebook-github-bot
Gavin Peng
authored and committed
Multithread CPU naive mesh rasterization
Summary: Threaded the for loop `for (int yi = 0; yi < H; ++yi) {...}` in function `RasterizeMeshesNaiveCpu()`. The image rows are split into approximately equal-sized chunks, one chunk per thread. Reviewed By: bottler. Differential Revision: D40063604. fbshipit-source-id: 09150269405538119b0f1b029892179501421e68
1 parent 37bd280 commit 6471893

File tree

3 files changed

+121
-47
lines changed

3 files changed

+121
-47
lines changed

pytorch3d/csrc/rasterize_meshes/rasterize_meshes_cpu.cpp

+106-44
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,9 @@
1010
#include <algorithm>
1111
#include <list>
1212
#include <queue>
13+
#include <thread>
1314
#include <tuple>
15+
#include "ATen/core/TensorAccessor.h"
1416
#include "rasterize_points/rasterization_utils.h"
1517
#include "utils/geometry_utils.h"
1618
#include "utils/vec2.h"
@@ -117,54 +119,28 @@ struct IsNeighbor {
117119
int neighbor_idx;
118120
};
119121

120-
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
121-
RasterizeMeshesNaiveCpu(
122-
const torch::Tensor& face_verts,
122+
namespace {
123+
void RasterizeMeshesNaiveCpu_worker(
124+
const int start_yi,
125+
const int end_yi,
123126
const torch::Tensor& mesh_to_face_first_idx,
124127
const torch::Tensor& num_faces_per_mesh,
125-
const torch::Tensor& clipped_faces_neighbor_idx,
126-
const std::tuple<int, int> image_size,
127128
const float blur_radius,
128-
const int faces_per_pixel,
129129
const bool perspective_correct,
130130
const bool clip_barycentric_coords,
131-
const bool cull_backfaces) {
132-
if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
133-
face_verts.size(2) != 3) {
134-
AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");
135-
}
136-
if (num_faces_per_mesh.size(0) != mesh_to_face_first_idx.size(0)) {
137-
AT_ERROR(
138-
"num_faces_per_mesh must have save size first dimension as mesh_to_face_first_idx");
139-
}
140-
141-
const int32_t N = mesh_to_face_first_idx.size(0); // batch_size.
142-
const int H = std::get<0>(image_size);
143-
const int W = std::get<1>(image_size);
144-
const int K = faces_per_pixel;
145-
146-
auto long_opts = num_faces_per_mesh.options().dtype(torch::kInt64);
147-
auto float_opts = face_verts.options().dtype(torch::kFloat32);
148-
149-
// Initialize output tensors.
150-
torch::Tensor face_idxs = torch::full({N, H, W, K}, -1, long_opts);
151-
torch::Tensor zbuf = torch::full({N, H, W, K}, -1, float_opts);
152-
torch::Tensor pix_dists = torch::full({N, H, W, K}, -1, float_opts);
153-
torch::Tensor barycentric_coords =
154-
torch::full({N, H, W, K, 3}, -1, float_opts);
155-
156-
auto face_verts_a = face_verts.accessor<float, 3>();
157-
auto face_idxs_a = face_idxs.accessor<int64_t, 4>();
158-
auto zbuf_a = zbuf.accessor<float, 4>();
159-
auto pix_dists_a = pix_dists.accessor<float, 4>();
160-
auto barycentric_coords_a = barycentric_coords.accessor<float, 5>();
161-
auto neighbor_idx_a = clipped_faces_neighbor_idx.accessor<int64_t, 1>();
162-
163-
auto face_bboxes = ComputeFaceBoundingBoxes(face_verts);
164-
auto face_bboxes_a = face_bboxes.accessor<float, 2>();
165-
auto face_areas = ComputeFaceAreas(face_verts);
166-
auto face_areas_a = face_areas.accessor<float, 1>();
167-
131+
const bool cull_backfaces,
132+
const int32_t N,
133+
const int H,
134+
const int W,
135+
const int K,
136+
at::TensorAccessor<float, 3>& face_verts_a,
137+
at::TensorAccessor<float, 1>& face_areas_a,
138+
at::TensorAccessor<float, 2>& face_bboxes_a,
139+
at::TensorAccessor<int64_t, 1>& neighbor_idx_a,
140+
at::TensorAccessor<float, 4>& zbuf_a,
141+
at::TensorAccessor<int64_t, 4>& face_idxs_a,
142+
at::TensorAccessor<float, 4>& pix_dists_a,
143+
at::TensorAccessor<float, 5>& barycentric_coords_a) {
168144
for (int n = 0; n < N; ++n) {
169145
// Loop through each mesh in the batch.
170146
// Get the start index of the faces in faces_packed and the num faces
@@ -174,7 +150,7 @@ RasterizeMeshesNaiveCpu(
174150
(face_start_idx + num_faces_per_mesh[n].item().to<int32_t>());
175151

176152
// Iterate through the horizontal lines of the image from top to bottom.
177-
for (int yi = 0; yi < H; ++yi) {
153+
for (int yi = start_yi; yi < end_yi; ++yi) {
178154
// Reverse the order of yi so that +Y is pointing upwards in the image.
179155
const int yidx = H - 1 - yi;
180156

@@ -324,6 +300,92 @@ RasterizeMeshesNaiveCpu(
324300
}
325301
}
326302
}
303+
}
304+
} // namespace
305+
306+
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
307+
RasterizeMeshesNaiveCpu(
308+
const torch::Tensor& face_verts,
309+
const torch::Tensor& mesh_to_face_first_idx,
310+
const torch::Tensor& num_faces_per_mesh,
311+
const torch::Tensor& clipped_faces_neighbor_idx,
312+
const std::tuple<int, int> image_size,
313+
const float blur_radius,
314+
const int faces_per_pixel,
315+
const bool perspective_correct,
316+
const bool clip_barycentric_coords,
317+
const bool cull_backfaces) {
318+
if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
319+
face_verts.size(2) != 3) {
320+
AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");
321+
}
322+
if (num_faces_per_mesh.size(0) != mesh_to_face_first_idx.size(0)) {
323+
AT_ERROR(
324+
"num_faces_per_mesh must have save size first dimension as mesh_to_face_first_idx");
325+
}
326+
327+
const int32_t N = mesh_to_face_first_idx.size(0); // batch_size.
328+
const int H = std::get<0>(image_size);
329+
const int W = std::get<1>(image_size);
330+
const int K = faces_per_pixel;
331+
332+
auto long_opts = num_faces_per_mesh.options().dtype(torch::kInt64);
333+
auto float_opts = face_verts.options().dtype(torch::kFloat32);
334+
335+
// Initialize output tensors.
336+
torch::Tensor face_idxs = torch::full({N, H, W, K}, -1, long_opts);
337+
torch::Tensor zbuf = torch::full({N, H, W, K}, -1, float_opts);
338+
torch::Tensor pix_dists = torch::full({N, H, W, K}, -1, float_opts);
339+
torch::Tensor barycentric_coords =
340+
torch::full({N, H, W, K, 3}, -1, float_opts);
341+
342+
auto face_verts_a = face_verts.accessor<float, 3>();
343+
auto face_idxs_a = face_idxs.accessor<int64_t, 4>();
344+
auto zbuf_a = zbuf.accessor<float, 4>();
345+
auto pix_dists_a = pix_dists.accessor<float, 4>();
346+
auto barycentric_coords_a = barycentric_coords.accessor<float, 5>();
347+
auto neighbor_idx_a = clipped_faces_neighbor_idx.accessor<int64_t, 1>();
348+
349+
auto face_bboxes = ComputeFaceBoundingBoxes(face_verts);
350+
auto face_bboxes_a = face_bboxes.accessor<float, 2>();
351+
auto face_areas = ComputeFaceAreas(face_verts);
352+
auto face_areas_a = face_areas.accessor<float, 1>();
353+
354+
const int64_t n_threads = at::get_num_threads();
355+
std::vector<std::thread> threads;
356+
threads.reserve(n_threads);
357+
const int chunk_size = 1 + (H - 1) / n_threads;
358+
int start_yi = 0;
359+
for (int iThread = 0; iThread < n_threads; ++iThread) {
360+
const int64_t end_yi = std::min(start_yi + chunk_size, H);
361+
threads.emplace_back(
362+
RasterizeMeshesNaiveCpu_worker,
363+
start_yi,
364+
end_yi,
365+
mesh_to_face_first_idx,
366+
num_faces_per_mesh,
367+
blur_radius,
368+
perspective_correct,
369+
clip_barycentric_coords,
370+
cull_backfaces,
371+
N,
372+
H,
373+
W,
374+
K,
375+
std::ref(face_verts_a),
376+
std::ref(face_areas_a),
377+
std::ref(face_bboxes_a),
378+
std::ref(neighbor_idx_a),
379+
std::ref(zbuf_a),
380+
std::ref(face_idxs_a),
381+
std::ref(pix_dists_a),
382+
std::ref(barycentric_coords_a));
383+
start_yi += chunk_size;
384+
}
385+
for (auto&& thread : threads) {
386+
thread.join();
387+
}
388+
327389
return std::make_tuple(face_idxs, zbuf, barycentric_coords, pix_dists);
328390
}
329391

tests/benchmarks/bm_rasterize_meshes.py

+4-2
Original file line number | Diff line number | Diff line change
@@ -4,13 +4,15 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
7+
import os
88
from itertools import product
99

1010
import torch
1111
from fvcore.common.benchmark import benchmark
1212
from tests.test_rasterize_meshes import TestRasterizeMeshes
1313

14+
BM_RASTERIZE_MESHES_N_THREADS = os.getenv("BM_RASTERIZE_MESHES_N_THREADS", 1)
15+
torch.set_num_threads(int(BM_RASTERIZE_MESHES_N_THREADS))
1416

1517
# ico levels:
1618
# 0: (12 verts, 20 faces)
@@ -41,7 +43,7 @@ def bm_rasterize_meshes() -> None:
4143
kwargs_list = []
4244
num_meshes = [1]
4345
ico_level = [1]
44-
image_size = [64, 128]
46+
image_size = [64, 128, 512]
4547
blur = [1e-6]
4648
faces_per_pixel = [3, 50]
4749
test_cases = product(num_meshes, ico_level, image_size, blur, faces_per_pixel)

tests/test_rasterize_meshes.py

+11-1
Original file line number | Diff line number | Diff line change
@@ -35,14 +35,24 @@ def test_simple_python(self):
3535
self._test_barycentric_clipping(rasterize_meshes_python, device, bin_size=-1)
3636
self._test_back_face_culling(rasterize_meshes_python, device, bin_size=-1)
3737

38-
def test_simple_cpu_naive(self):
38+
def _test_simple_cpu_naive_instance(self):
3939
device = torch.device("cpu")
4040
self._simple_triangle_raster(rasterize_meshes, device, bin_size=0)
4141
self._simple_blurry_raster(rasterize_meshes, device, bin_size=0)
4242
self._test_behind_camera(rasterize_meshes, device, bin_size=0)
4343
self._test_perspective_correct(rasterize_meshes, device, bin_size=0)
4444
self._test_back_face_culling(rasterize_meshes, device, bin_size=0)
4545

46+
def test_simple_cpu_naive(self):
47+
n_threads = torch.get_num_threads()
48+
torch.set_num_threads(1) # single threaded
49+
self._test_simple_cpu_naive_instance()
50+
torch.set_num_threads(4) # even (divisible) number of threads
51+
self._test_simple_cpu_naive_instance()
52+
torch.set_num_threads(5) # odd (nondivisible) number of threads
53+
self._test_simple_cpu_naive_instance()
54+
torch.set_num_threads(n_threads)
55+
4656
def test_simple_cuda_naive(self):
4757
device = get_random_cuda_device()
4858
self._simple_triangle_raster(rasterize_meshes, device, bin_size=0)

0 commit comments

Comments
 (0)