Skip to content

Commit 53266ec

Browse files
nikhilaravifacebook-github-bot
authored andcommitted
C++ IoU for 3D Boxes
Summary: C++ Implementation of algorithm to compute 3D bounding boxes for batches of bboxes of shape (N, 8, 3) and (M, 8, 3). Reviewed By: gkioxari Differential Revision: D30905190 fbshipit-source-id: 02e2cf025cd4fa3ff706ce5cf9b82c0fb5443f96
1 parent 2293f1f commit 53266ec

File tree

7 files changed

+927
-29
lines changed

7 files changed

+927
-29
lines changed

pytorch3d/csrc/ext.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "face_areas_normals/face_areas_normals.h"
2121
#include "gather_scatter/gather_scatter.h"
2222
#include "interp_face_attrs/interp_face_attrs.h"
23+
#include "iou_box3d/iou_box3d.h"
2324
#include "knn/knn.h"
2425
#include "mesh_normal_consistency/mesh_normal_consistency.h"
2526
#include "packed_to_padded_tensor/packed_to_padded_tensor.h"
@@ -87,6 +88,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
8788
// Sample PDF
8889
m.def("sample_pdf", &SamplePdf);
8990

91+
// 3D IoU
92+
m.def("iou_box3d", &IoUBox3D);
93+
9094
// Pulsar.
9195
#ifdef PULSAR_LOGGING_ENABLED
9296
c10::ShowLogInfoToStderr();

pytorch3d/csrc/iou_box3d/iou_box3d.h

+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
/*
2+
* Copyright (c) Facebook, Inc. and its affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
#include <torch/extension.h>
11+
#include <tuple>
12+
#include "utils/pytorch3d_cutils.h"
13+
14+
// Calculate the intersection volume and IoU metric for two batches of boxes
15+
//
16+
// Args:
17+
// boxes1: tensor of shape (N, 8, 3) of the coordinates of the 1st boxes
18+
// boxes2: tensor of shape (M, 8, 3) of the coordinates of the 2nd boxes
19+
// Returns:
20+
// vol: (N, M) tensor of the volume of the intersecting convex shapes
21+
// iou: (N, M) tensor of the intersection over union which is
22+
// defined as: `iou = vol / (vol1 + vol2 - vol)`
23+
24+
// CPU implementation
25+
std::tuple<at::Tensor, at::Tensor> IoUBox3DCpu(
26+
const at::Tensor& boxes1,
27+
const at::Tensor& boxes2);
28+
29+
// Implementation which is exposed
30+
inline std::tuple<at::Tensor, at::Tensor> IoUBox3D(
31+
const at::Tensor& boxes1,
32+
const at::Tensor& boxes2) {
33+
if (boxes1.is_cuda() || boxes2.is_cuda()) {
34+
AT_ERROR("GPU support not implemented");
35+
}
36+
return IoUBox3DCpu(boxes1.contiguous(), boxes2.contiguous());
37+
}
+121
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
/*
2+
* Copyright (c) Facebook, Inc. and its affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <torch/extension.h>
10+
#include <torch/torch.h>
11+
#include <list>
12+
#include <numeric>
13+
#include <queue>
14+
#include <tuple>
15+
#include "iou_box3d/iou_utils.h"
16+
17+
std::tuple<at::Tensor, at::Tensor> IoUBox3DCpu(
18+
const at::Tensor& boxes1,
19+
const at::Tensor& boxes2) {
20+
const int N = boxes1.size(0);
21+
const int M = boxes2.size(0);
22+
auto float_opts = boxes1.options().dtype(torch::kFloat32);
23+
torch::Tensor vols = torch::zeros({N, M}, float_opts);
24+
torch::Tensor ious = torch::zeros({N, M}, float_opts);
25+
26+
// Create tensor accessors
27+
auto boxes1_a = boxes1.accessor<float, 3>();
28+
auto boxes2_a = boxes2.accessor<float, 3>();
29+
auto vols_a = vols.accessor<float, 2>();
30+
auto ious_a = ious.accessor<float, 2>();
31+
32+
// Iterate through the N boxes in boxes1
33+
for (int n = 0; n < N; ++n) {
34+
const auto& box1 = boxes1_a[n];
35+
// Convert to vector of face vertices i.e. effectively (F, 3, 3)
36+
// face_verts is a data type defined in iou_utils.h
37+
const face_verts box1_tris = GetBoxTris(box1);
38+
39+
// Calculate the position of the center of the box which is used in
40+
// several calculations. This requires a tensor as input.
41+
const vec3<float> box1_center = BoxCenter(boxes1[n]);
42+
43+
// Convert to vector of face vertices i.e. effectively (P, 4, 3)
44+
const face_verts box1_planes = GetBoxPlanes(box1);
45+
46+
// Get Box Volumes
47+
const float box1_vol = BoxVolume(box1_tris, box1_center);
48+
49+
// Iterate through the M boxes in boxes2
50+
for (int m = 0; m < M; ++m) {
51+
// Repeat above steps for box2
52+
// TODO: check if caching these value helps performance.
53+
const auto& box2 = boxes2_a[m];
54+
const face_verts box2_tris = GetBoxTris(box2);
55+
const vec3<float> box2_center = BoxCenter(boxes2[m]);
56+
const face_verts box2_planes = GetBoxPlanes(box2);
57+
const float box2_vol = BoxVolume(box2_tris, box2_center);
58+
59+
// Every triangle in one box will be compared to each plane in the other
60+
// box. There are 3 possible outcomes:
61+
// 1. If the triangle is fully inside, then it will
62+
// remain as is.
63+
// 2. If the triagnle it is fully outside, it will be removed.
64+
// 3. If the triangle intersects with the (infinite) plane, it
65+
// will be broken into subtriangles such that each subtriangle is full
66+
// inside the plane and part of the intersecting tetrahedron.
67+
68+
// Tris in Box1 -> Planes in Box2
69+
face_verts box1_intersect =
70+
BoxIntersections(box1_tris, box2_planes, box2_center);
71+
// Tris in Box2 -> Planes in Box1
72+
face_verts box2_intersect =
73+
BoxIntersections(box2_tris, box1_planes, box1_center);
74+
75+
// If there are overlapping regions in Box2, remove any coplanar faces
76+
if (box2_intersect.size() > 0) {
77+
// Identify if any triangles in Box2 are coplanar with Box1
78+
std::vector<int> tri2_keep(box2_intersect.size());
79+
std::fill(tri2_keep.begin(), tri2_keep.end(), 1);
80+
for (int b1 = 0; b1 < box1_intersect.size(); ++b1) {
81+
for (int b2 = 0; b2 < box2_intersect.size(); ++b2) {
82+
bool is_coplanar =
83+
IsCoplanarFace(box1_intersect[b1], box2_intersect[b2]);
84+
if (is_coplanar) {
85+
tri2_keep[b2] = 0;
86+
}
87+
}
88+
}
89+
90+
// Keep only the non coplanar triangles in Box2 - add them to the
91+
// Box1 triangles.
92+
for (int b2 = 0; b2 < box2_intersect.size(); ++b2) {
93+
if (tri2_keep[b2] == 1) {
94+
box1_intersect.push_back((box2_intersect[b2]));
95+
}
96+
}
97+
}
98+
99+
// Initialize the vol and iou to 0.0 in case there are no triangles
100+
// in the intersecting shape.
101+
float vol = 0.0;
102+
float iou = 0.0;
103+
104+
// If there are triangles in the intersecting shape
105+
if (box1_intersect.size() > 0) {
106+
// The intersecting shape is a polyhedron made up of the
107+
// triangular faces that are all now in box1_intersect.
108+
// Calculate the polyhedron center
109+
const vec3<float> polyhedron_center = PolyhedronCenter(box1_intersect);
110+
// Compute intersecting polyhedron volume
111+
vol = BoxVolume(box1_intersect, polyhedron_center);
112+
// Compute IoU
113+
iou = vol / (box1_vol + box2_vol - vol);
114+
}
115+
// Save out volume and IoU
116+
vols_a[n][m] = vol;
117+
ious_a[n][m] = iou;
118+
}
119+
}
120+
return std::make_tuple(vols, ious);
121+
}

0 commit comments

Comments
 (0)