8
8
from common_testing import TestCaseMixin
9
9
from pytorch3d .ops .vert_align import vert_align
10
10
from pytorch3d .structures .meshes import Meshes
11
+ from pytorch3d .structures .pointclouds import Pointclouds
11
12
12
13
13
14
class TestVertAlign (TestCaseMixin , unittest .TestCase ):
14
15
@staticmethod
15
16
def vert_align_naive (
16
- feats , verts_or_meshes , return_packed : bool = False , align_corners : bool = True
17
+ feats , verts , return_packed : bool = False , align_corners : bool = True
17
18
):
18
19
"""
19
20
Naive implementation of vert_align.
@@ -28,12 +29,12 @@ def vert_align_naive(
28
29
out_i_feats = []
29
30
for feat in feats :
30
31
feats_i = feat [i ][None , :, :, :] # (1, C, H, W)
31
- if torch .is_tensor (verts_or_meshes ):
32
- grid = verts_or_meshes [i ][None , None , :, :2 ] # (1, 1, V, 2)
33
- elif hasattr (verts_or_meshes , "verts_list" ):
34
- grid = verts_or_meshes .verts_list ()[i ][
35
- None , None , :, : 2
36
- ] # (1, 1, V, 2)
32
+ if torch .is_tensor (verts ):
33
+ grid = verts [i ][None , None , :, :2 ] # (1, 1, V, 2)
34
+ elif hasattr (verts , "verts_list" ):
35
+ grid = verts .verts_list ()[i ][None , None , :, : 2 ] # (1, 1, V, 2)
36
+ elif hasattr ( verts , "points_list" ):
37
+ grid = verts . points_list ()[ i ][ None , None , :, : 2 ] # (1, 1, V, 2)
37
38
else :
38
39
raise ValueError ("verts_or_meshes is invalid" )
39
40
feat_sampled_i = F .grid_sample (
@@ -56,7 +57,9 @@ def vert_align_naive(
56
57
return out_feats
57
58
58
59
@staticmethod
59
- def init_meshes (num_meshes : int = 10 , num_verts : int = 1000 , num_faces : int = 3000 ):
60
+ def init_meshes (
61
+ num_meshes : int = 10 , num_verts : int = 1000 , num_faces : int = 3000
62
+ ) -> Meshes :
60
63
device = torch .device ("cuda:0" )
61
64
verts_list = []
62
65
faces_list = []
@@ -74,6 +77,20 @@ def init_meshes(num_meshes: int = 10, num_verts: int = 1000, num_faces: int = 30
74
77
75
78
return meshes
76
79
80
+ @staticmethod
81
+ def init_pointclouds (num_clouds : int = 10 , num_points : int = 1000 ) -> Pointclouds :
82
+ device = torch .device ("cuda:0" )
83
+ points_list = []
84
+ for _ in range (num_clouds ):
85
+ points = (
86
+ torch .rand ((num_points , 3 ), dtype = torch .float32 , device = device ) * 2.0
87
+ - 1.0
88
+ ) # points in the space of [-1, 1]
89
+ points_list .append (points )
90
+ pointclouds = Pointclouds (points = points_list )
91
+
92
+ return pointclouds
93
+
77
94
@staticmethod
78
95
def init_feats (batch_size : int = 10 , num_channels : int = 256 , device : str = "cuda" ):
79
96
H , W = [14 , 28 ], [14 , 28 ]
@@ -99,6 +116,27 @@ def test_vert_align_with_meshes(self):
99
116
naive_out = TestVertAlign .vert_align_naive (feats [0 ], meshes , return_packed = True )
100
117
self .assertClose (out , naive_out )
101
118
119
+ def test_vert_align_with_pointclouds (self ):
120
+ """
121
+ Test vert align vs naive implementation with meshes.
122
+ """
123
+ pointclouds = TestVertAlign .init_pointclouds (10 , 1000 )
124
+ feats = TestVertAlign .init_feats (10 , 256 )
125
+
126
+ # feats in list
127
+ out = vert_align (feats , pointclouds , return_packed = True )
128
+ naive_out = TestVertAlign .vert_align_naive (
129
+ feats , pointclouds , return_packed = True
130
+ )
131
+ self .assertClose (out , naive_out )
132
+
133
+ # feats as tensor
134
+ out = vert_align (feats [0 ], pointclouds , return_packed = True )
135
+ naive_out = TestVertAlign .vert_align_naive (
136
+ feats [0 ], pointclouds , return_packed = True
137
+ )
138
+ self .assertClose (out , naive_out )
139
+
102
140
def test_vert_align_with_verts (self ):
103
141
"""
104
142
Test vert align vs naive implementation with verts as tensor.
0 commit comments