Skip to content

Commit 0d8608b

Browse files
Jiali Duanfacebook-github-bot
Jiali Duan
authored andcommitted
Marching Cubes C++ torch extension
Summary: Torch C++ extension for Marching Cubes - Add torch C++ extension for marching cubes. Observe a speed up of ~255x-324x speed up (over varying batch sizes and spatial resolutions) - Add C++ impl in existing unit-tests. (Note: this ignores all push blocking failures!) Reviewed By: kjchalup Differential Revision: D39590638 fbshipit-source-id: e44d2852a24c2c398e5ea9db20f0dfaa1817e457
1 parent 850efdf commit 0d8608b

File tree

7 files changed

+879
-9
lines changed

7 files changed

+879
-9
lines changed

pytorch3d/csrc/ext.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "interp_face_attrs/interp_face_attrs.h"
2323
#include "iou_box3d/iou_box3d.h"
2424
#include "knn/knn.h"
25+
#include "marching_cubes/marching_cubes.h"
2526
#include "mesh_normal_consistency/mesh_normal_consistency.h"
2627
#include "packed_to_padded_tensor/packed_to_padded_tensor.h"
2728
#include "point_mesh/point_mesh_cuda.h"
@@ -94,6 +95,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
9495
// 3D IoU
9596
m.def("iou_box3d", &IoUBox3D);
9697

98+
// Marching cubes
99+
m.def("marching_cubes", &MarchingCubes);
100+
97101
// Pulsar.
98102
#ifdef PULSAR_LOGGING_ENABLED
99103
c10::ShowLogInfoToStderr();
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
#include <torch/extension.h>
11+
#include <tuple>
12+
#include <vector>
13+
#include "utils/pytorch3d_cutils.h"
14+
15+
// Run Marching Cubes algorithm over a batch of volume scalar fields
16+
// with a pre-defined threshold and return a mesh composed of vertices
17+
// and faces for the mesh.
18+
//
19+
// Args:
20+
// vol: FloatTensor of shape (D, H, W) giving a volume
21+
// scalar grids.
22+
// isolevel: isosurface value to use as the threshoold to determine whether
23+
// the points are within a volume.
24+
//
25+
// Returns:
26+
// vertices: List of N FloatTensors of vertices
27+
// faces: List of N LongTensors of faces
28+
29+
// CPU implementation
30+
std::tuple<at::Tensor, at::Tensor> MarchingCubesCpu(
31+
const at::Tensor& vol,
32+
const float isolevel);
33+
34+
// Implementation which is exposed
35+
inline std::tuple<at::Tensor, at::Tensor> MarchingCubes(
36+
const at::Tensor& vol,
37+
const float isolevel) {
38+
return MarchingCubesCpu(vol.contiguous(), isolevel);
39+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <torch/extension.h>
10+
#include <algorithm>
11+
#include <array>
12+
#include <cstring>
13+
#include <unordered_map>
14+
#include <vector>
15+
#include "marching_cubes/marching_cubes_utils.h"
16+
17+
// Cpu implementation for Marching Cubes
18+
// Args:
19+
// vol: a Tensor of size (D, H, W) corresponding to a 3D scalar field
20+
// isolevel: the isosurface value to use as the threshold to determine
21+
// whether points are within a volume.
22+
//
23+
// Returns:
24+
// vertices: a float tensor of shape (N, 3) for positions of the mesh
25+
// faces: a long tensor of shape (N, 3) for indices of the face vertices
26+
//
27+
std::tuple<at::Tensor, at::Tensor> MarchingCubesCpu(
28+
const at::Tensor& vol,
29+
const float isolevel) {
30+
// volume shapes
31+
const int D = vol.size(0);
32+
const int H = vol.size(1);
33+
const int W = vol.size(2);
34+
35+
// Create tensor accessors
36+
auto vol_a = vol.accessor<float, 3>();
37+
// vpair_to_edge maps a pair of vertex ids to its corresponding edge id
38+
std::unordered_map<std::pair<int, int>, int64_t> vpair_to_edge;
39+
// edge_id_to_v maps from an edge id to a vertex position
40+
std::unordered_map<int64_t, Vertex> edge_id_to_v;
41+
// uniq_edge_id: used to remove redundant edge ids
42+
std::unordered_map<int64_t, int64_t> uniq_edge_id;
43+
std::vector<int64_t> faces; // store face indices
44+
std::vector<Vertex> verts; // store vertex positions
45+
// enumerate each cell in the 3d grid
46+
for (int z = 0; z < D - 1; z++) {
47+
for (int y = 0; y < H - 1; y++) {
48+
for (int x = 0; x < W - 1; x++) {
49+
Cube cube(x, y, z, vol_a, isolevel);
50+
// Cube is entirely in/out of the surface
51+
if (_FACE_TABLE[cube.cubeindex][0] == -1) {
52+
continue;
53+
}
54+
// store all boundary vertices that intersect with the edges
55+
std::array<Vertex, 12> interp_points;
56+
// triangle vertex IDs and positions
57+
std::vector<int64_t> tri;
58+
std::vector<Vertex> ps;
59+
60+
// Interpolate the vertices where the surface intersects with the cube
61+
for (int j = 0; _FACE_TABLE[cube.cubeindex][j] != -1; j++) {
62+
const int e = _FACE_TABLE[cube.cubeindex][j];
63+
interp_points[e] = cube.VertexInterp(isolevel, e, vol_a);
64+
65+
auto vpair = cube.GetVPairFromEdge(e, W, H);
66+
if (!vpair_to_edge.count(vpair)) {
67+
vpair_to_edge[vpair] = vpair_to_edge.size();
68+
}
69+
70+
int64_t edge = vpair_to_edge[vpair];
71+
tri.push_back(edge);
72+
ps.push_back(interp_points[e]);
73+
74+
// Check if the triangle face is degenerate. A triangle face
75+
// is degenerate if any of the two verices share the same 3D position
76+
if ((j + 1) % 3 == 0 && ps[0] != ps[1] && ps[1] != ps[2] &&
77+
ps[2] != ps[0]) {
78+
for (int k = 0; k < 3; k++) {
79+
int v = tri[k];
80+
edge_id_to_v[tri.at(k)] = ps.at(k);
81+
if (!uniq_edge_id.count(v)) {
82+
uniq_edge_id[v] = verts.size();
83+
verts.push_back(edge_id_to_v[v]);
84+
}
85+
faces.push_back(uniq_edge_id[v]);
86+
}
87+
tri.clear();
88+
ps.clear();
89+
}
90+
} // endif
91+
} // endfor x
92+
} // endfor y
93+
} // endfor z
94+
// Collect returning tensor
95+
const int n_vertices = verts.size();
96+
const int64_t n_faces = (int64_t)faces.size() / 3;
97+
auto vert_tensor = torch::zeros({n_vertices, 3}, torch::kFloat);
98+
auto face_tensor = torch::zeros({n_faces, 3}, torch::kInt64);
99+
100+
auto vert_a = vert_tensor.accessor<float, 2>();
101+
for (int i = 0; i < n_vertices; i++) {
102+
vert_a[i][0] = verts.at(i).x;
103+
vert_a[i][1] = verts.at(i).y;
104+
vert_a[i][2] = verts.at(i).z;
105+
}
106+
107+
auto face_a = face_tensor.accessor<int64_t, 2>();
108+
for (int64_t i = 0; i < n_faces; i++) {
109+
face_a[i][0] = faces.at(i * 3 + 0);
110+
face_a[i][1] = faces.at(i * 3 + 1);
111+
face_a[i][2] = faces.at(i * 3 + 2);
112+
}
113+
114+
return std::make_tuple(vert_tensor, face_tensor);
115+
}

0 commit comments

Comments
 (0)