|
| 1 | +/* |
| 2 | + * Copyright (c) Facebook, Inc. and its affiliates. |
| 3 | + * All rights reserved. |
| 4 | + * |
| 5 | + * This source code is licensed under the BSD-style license found in the |
| 6 | + * LICENSE file in the root directory of this source tree. |
| 7 | + */ |
| 8 | + |
| 9 | +#include <torch/extension.h> |
| 10 | +#include <torch/torch.h> |
| 11 | +#include <list> |
| 12 | +#include <numeric> |
| 13 | +#include <queue> |
| 14 | +#include <tuple> |
| 15 | +#include "iou_box3d/iou_utils.h" |
| 16 | + |
| 17 | +std::tuple<at::Tensor, at::Tensor> IoUBox3DCpu( |
| 18 | + const at::Tensor& boxes1, |
| 19 | + const at::Tensor& boxes2) { |
| 20 | + const int N = boxes1.size(0); |
| 21 | + const int M = boxes2.size(0); |
| 22 | + auto float_opts = boxes1.options().dtype(torch::kFloat32); |
| 23 | + torch::Tensor vols = torch::zeros({N, M}, float_opts); |
| 24 | + torch::Tensor ious = torch::zeros({N, M}, float_opts); |
| 25 | + |
| 26 | + // Create tensor accessors |
| 27 | + auto boxes1_a = boxes1.accessor<float, 3>(); |
| 28 | + auto boxes2_a = boxes2.accessor<float, 3>(); |
| 29 | + auto vols_a = vols.accessor<float, 2>(); |
| 30 | + auto ious_a = ious.accessor<float, 2>(); |
| 31 | + |
| 32 | + // Iterate through the N boxes in boxes1 |
| 33 | + for (int n = 0; n < N; ++n) { |
| 34 | + const auto& box1 = boxes1_a[n]; |
| 35 | + // Convert to vector of face vertices i.e. effectively (F, 3, 3) |
| 36 | + // face_verts is a data type defined in iou_utils.h |
| 37 | + const face_verts box1_tris = GetBoxTris(box1); |
| 38 | + |
| 39 | + // Calculate the position of the center of the box which is used in |
| 40 | + // several calculations. This requires a tensor as input. |
| 41 | + const vec3<float> box1_center = BoxCenter(boxes1[n]); |
| 42 | + |
| 43 | + // Convert to vector of face vertices i.e. effectively (P, 4, 3) |
| 44 | + const face_verts box1_planes = GetBoxPlanes(box1); |
| 45 | + |
| 46 | + // Get Box Volumes |
| 47 | + const float box1_vol = BoxVolume(box1_tris, box1_center); |
| 48 | + |
| 49 | + // Iterate through the M boxes in boxes2 |
| 50 | + for (int m = 0; m < M; ++m) { |
| 51 | + // Repeat above steps for box2 |
| 52 | + // TODO: check if caching these value helps performance. |
| 53 | + const auto& box2 = boxes2_a[m]; |
| 54 | + const face_verts box2_tris = GetBoxTris(box2); |
| 55 | + const vec3<float> box2_center = BoxCenter(boxes2[m]); |
| 56 | + const face_verts box2_planes = GetBoxPlanes(box2); |
| 57 | + const float box2_vol = BoxVolume(box2_tris, box2_center); |
| 58 | + |
| 59 | + // Every triangle in one box will be compared to each plane in the other |
| 60 | + // box. There are 3 possible outcomes: |
| 61 | + // 1. If the triangle is fully inside, then it will |
| 62 | + // remain as is. |
| 63 | + // 2. If the triagnle it is fully outside, it will be removed. |
| 64 | + // 3. If the triangle intersects with the (infinite) plane, it |
| 65 | + // will be broken into subtriangles such that each subtriangle is full |
| 66 | + // inside the plane and part of the intersecting tetrahedron. |
| 67 | + |
| 68 | + // Tris in Box1 -> Planes in Box2 |
| 69 | + face_verts box1_intersect = |
| 70 | + BoxIntersections(box1_tris, box2_planes, box2_center); |
| 71 | + // Tris in Box2 -> Planes in Box1 |
| 72 | + face_verts box2_intersect = |
| 73 | + BoxIntersections(box2_tris, box1_planes, box1_center); |
| 74 | + |
| 75 | + // If there are overlapping regions in Box2, remove any coplanar faces |
| 76 | + if (box2_intersect.size() > 0) { |
| 77 | + // Identify if any triangles in Box2 are coplanar with Box1 |
| 78 | + std::vector<int> tri2_keep(box2_intersect.size()); |
| 79 | + std::fill(tri2_keep.begin(), tri2_keep.end(), 1); |
| 80 | + for (int b1 = 0; b1 < box1_intersect.size(); ++b1) { |
| 81 | + for (int b2 = 0; b2 < box2_intersect.size(); ++b2) { |
| 82 | + bool is_coplanar = |
| 83 | + IsCoplanarFace(box1_intersect[b1], box2_intersect[b2]); |
| 84 | + if (is_coplanar) { |
| 85 | + tri2_keep[b2] = 0; |
| 86 | + } |
| 87 | + } |
| 88 | + } |
| 89 | + |
| 90 | + // Keep only the non coplanar triangles in Box2 - add them to the |
| 91 | + // Box1 triangles. |
| 92 | + for (int b2 = 0; b2 < box2_intersect.size(); ++b2) { |
| 93 | + if (tri2_keep[b2] == 1) { |
| 94 | + box1_intersect.push_back((box2_intersect[b2])); |
| 95 | + } |
| 96 | + } |
| 97 | + } |
| 98 | + |
| 99 | + // Initialize the vol and iou to 0.0 in case there are no triangles |
| 100 | + // in the intersecting shape. |
| 101 | + float vol = 0.0; |
| 102 | + float iou = 0.0; |
| 103 | + |
| 104 | + // If there are triangles in the intersecting shape |
| 105 | + if (box1_intersect.size() > 0) { |
| 106 | + // The intersecting shape is a polyhedron made up of the |
| 107 | + // triangular faces that are all now in box1_intersect. |
| 108 | + // Calculate the polyhedron center |
| 109 | + const vec3<float> polyhedron_center = PolyhedronCenter(box1_intersect); |
| 110 | + // Compute intersecting polyhedron volume |
| 111 | + vol = BoxVolume(box1_intersect, polyhedron_center); |
| 112 | + // Compute IoU |
| 113 | + iou = vol / (box1_vol + box2_vol - vol); |
| 114 | + } |
| 115 | + // Save out volume and IoU |
| 116 | + vols_a[n][m] = vol; |
| 117 | + ious_a[n][m] = iou; |
| 118 | + } |
| 119 | + } |
| 120 | + return std::make_tuple(vols, ious); |
| 121 | +} |
0 commit comments