Skip to content

Commit f25af96

Browse files
bottlerfacebook-github-bot
authored andcommitted
vert_align for Pointclouds object
Reviewed By: gkioxari Differential Revision: D21088730 fbshipit-source-id: f8c125ac8c8009d45712ae63237ca64acf1faf45
1 parent e19df58 commit f25af96

File tree

2 files changed

+61
-20
lines changed

2 files changed

+61
-20
lines changed

pytorch3d/ops/vert_align.py

+15-12
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,10 @@ def vert_align(
2525
feats: FloatTensor of shape (N, C, H, W) representing image features
2626
from which to sample or a list of features each with potentially
2727
different C, H or W dimensions.
28-
verts: FloatTensor of shape (N, V, 3) or an object (e.g. Meshes) with
29-
'verts_padded' as an attribute giving the (x, y, z) vertex positions
30-
for which to sample. (x, y) verts should be normalized such that
31-
(-1, -1) corresponds to top-left and (+1, +1) to bottom-right
28+
verts: FloatTensor of shape (N, V, 3) or an object (e.g. Meshes or Pointclouds)
29+
with `verts_padded' or `points_padded' as an attribute giving the (x, y, z)
30+
vertex positions for which to sample. (x, y) verts should be normalized such
31+
that (-1, -1) corresponds to top-left and (+1, +1) to bottom-right
3232
location in the input feature map.
3333
return_packed: (bool) Indicates whether to return packed features
3434
interp_mode: (str) Specifies how to interpolate features.
@@ -44,22 +44,25 @@ def vert_align(
4444
resolution agnostic. Default: ``True``
4545
4646
Returns:
47-
feats_sampled: FloatTensor of shape (N, V, C) giving sampled features for
48-
each vertex. If feats is a list, we return concatentated
49-
features in axis=2 of shape (N, V, sum(C_n)) where
50-
C_n = feats[n].shape[1]. If return_packed = True, the
51-
features are transformed to a packed representation
52-
of shape (sum(V), C)
53-
47+
feats_sampled: FloatTensor of shape (N, V, C) giving sampled features for each
48+
vertex. If feats is a list, we return concatentated features in axis=2 of
49+
shape (N, V, sum(C_n)) where C_n = feats[n].shape[1].
50+
If return_packed = True, the features are transformed to a packed
51+
representation of shape (sum(V), C)
5452
"""
5553
if torch.is_tensor(verts):
5654
if verts.dim() != 3:
5755
raise ValueError("verts tensor should be 3 dimensional")
5856
grid = verts
5957
elif hasattr(verts, "verts_padded"):
6058
grid = verts.verts_padded()
59+
elif hasattr(verts, "points_padded"):
60+
grid = verts.points_padded()
6161
else:
62-
raise ValueError("verts must be a tensor or have a `verts_padded` attribute")
62+
raise ValueError(
63+
"verts must be a tensor or have a "
64+
+ "`points_padded' or`verts_padded` attribute."
65+
)
6366

6467
grid = grid[:, None, :, :2] # (N, 1, V, 2)
6568

tests/test_vert_align.py

+46-8
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,13 @@
88
from common_testing import TestCaseMixin
99
from pytorch3d.ops.vert_align import vert_align
1010
from pytorch3d.structures.meshes import Meshes
11+
from pytorch3d.structures.pointclouds import Pointclouds
1112

1213

1314
class TestVertAlign(TestCaseMixin, unittest.TestCase):
1415
@staticmethod
1516
def vert_align_naive(
16-
feats, verts_or_meshes, return_packed: bool = False, align_corners: bool = True
17+
feats, verts, return_packed: bool = False, align_corners: bool = True
1718
):
1819
"""
1920
Naive implementation of vert_align.
@@ -28,12 +29,12 @@ def vert_align_naive(
2829
out_i_feats = []
2930
for feat in feats:
3031
feats_i = feat[i][None, :, :, :] # (1, C, H, W)
31-
if torch.is_tensor(verts_or_meshes):
32-
grid = verts_or_meshes[i][None, None, :, :2] # (1, 1, V, 2)
33-
elif hasattr(verts_or_meshes, "verts_list"):
34-
grid = verts_or_meshes.verts_list()[i][
35-
None, None, :, :2
36-
] # (1, 1, V, 2)
32+
if torch.is_tensor(verts):
33+
grid = verts[i][None, None, :, :2] # (1, 1, V, 2)
34+
elif hasattr(verts, "verts_list"):
35+
grid = verts.verts_list()[i][None, None, :, :2] # (1, 1, V, 2)
36+
elif hasattr(verts, "points_list"):
37+
grid = verts.points_list()[i][None, None, :, :2] # (1, 1, V, 2)
3738
else:
3839
raise ValueError("verts_or_meshes is invalid")
3940
feat_sampled_i = F.grid_sample(
@@ -56,7 +57,9 @@ def vert_align_naive(
5657
return out_feats
5758

5859
@staticmethod
59-
def init_meshes(num_meshes: int = 10, num_verts: int = 1000, num_faces: int = 3000):
60+
def init_meshes(
61+
num_meshes: int = 10, num_verts: int = 1000, num_faces: int = 3000
62+
) -> Meshes:
6063
device = torch.device("cuda:0")
6164
verts_list = []
6265
faces_list = []
@@ -74,6 +77,20 @@ def init_meshes(num_meshes: int = 10, num_verts: int = 1000, num_faces: int = 30
7477

7578
return meshes
7679

80+
@staticmethod
81+
def init_pointclouds(num_clouds: int = 10, num_points: int = 1000) -> Pointclouds:
82+
device = torch.device("cuda:0")
83+
points_list = []
84+
for _ in range(num_clouds):
85+
points = (
86+
torch.rand((num_points, 3), dtype=torch.float32, device=device) * 2.0
87+
- 1.0
88+
) # points in the space of [-1, 1]
89+
points_list.append(points)
90+
pointclouds = Pointclouds(points=points_list)
91+
92+
return pointclouds
93+
7794
@staticmethod
7895
def init_feats(batch_size: int = 10, num_channels: int = 256, device: str = "cuda"):
7996
H, W = [14, 28], [14, 28]
@@ -99,6 +116,27 @@ def test_vert_align_with_meshes(self):
99116
naive_out = TestVertAlign.vert_align_naive(feats[0], meshes, return_packed=True)
100117
self.assertClose(out, naive_out)
101118

119+
def test_vert_align_with_pointclouds(self):
120+
"""
121+
Test vert align vs naive implementation with meshes.
122+
"""
123+
pointclouds = TestVertAlign.init_pointclouds(10, 1000)
124+
feats = TestVertAlign.init_feats(10, 256)
125+
126+
# feats in list
127+
out = vert_align(feats, pointclouds, return_packed=True)
128+
naive_out = TestVertAlign.vert_align_naive(
129+
feats, pointclouds, return_packed=True
130+
)
131+
self.assertClose(out, naive_out)
132+
133+
# feats as tensor
134+
out = vert_align(feats[0], pointclouds, return_packed=True)
135+
naive_out = TestVertAlign.vert_align_naive(
136+
feats[0], pointclouds, return_packed=True
137+
)
138+
self.assertClose(out, naive_out)
139+
102140
def test_vert_align_with_verts(self):
103141
"""
104142
Test vert align vs naive implementation with verts as tensor.

0 commit comments

Comments
 (0)