26
26
from PIL import Image
27
27
28
28
29
- def init_cube_point_cloud (
30
- batch_size : int = 10 , n_points : int = 100000 , rotate_y : bool = True
31
- ):
29
+ def init_cube_point_cloud (batch_size : int , n_points : int , device : str , rotate_y : bool ):
32
30
"""
33
31
Generate a random point cloud of `n_points` whose points
34
32
are sampled from faces of a 3D cube.
35
33
"""
36
34
37
35
# create the cube mesh batch_size times
38
- meshes = TestPointsToVolumes .init_cube_mesh (batch_size )
36
+ meshes = TestPointsToVolumes .init_cube_mesh (batch_size = batch_size , device = device )
39
37
40
38
# generate point clouds by sampling points from the meshes
41
39
pcl = sample_points_from_meshes (meshes , num_samples = n_points , return_normals = False )
@@ -66,7 +64,7 @@ def init_cube_point_cloud(
66
64
67
65
if rotate_y :
68
66
# uniformly spaced rotations around y axis
69
- R = init_uniform_y_rotations (batch_size = batch_size )
67
+ R = init_uniform_y_rotations (batch_size = batch_size , device = device )
70
68
# rotate the point clouds around y axis
71
69
pcl = torch .bmm (pcl - 0.5 , R ) + 0.5
72
70
@@ -78,6 +76,7 @@ def init_volume_boundary_pointcloud(
78
76
volume_size : Tuple [int , int , int ],
79
77
n_points : int ,
80
78
interp_mode : str ,
79
+ device : str ,
81
80
require_grad : bool = False ,
82
81
):
83
82
"""
@@ -86,7 +85,9 @@ def init_volume_boundary_pointcloud(
86
85
"""
87
86
88
87
# generate a 3D point cloud sampled from sides of a [0,1] cube
89
- xyz , rgb = init_cube_point_cloud (batch_size , n_points = n_points , rotate_y = True )
88
+ xyz , rgb = init_cube_point_cloud (
89
+ batch_size , n_points = n_points , device = device , rotate_y = True
90
+ )
90
91
91
92
# make volume_size tensor
92
93
volume_size_t = torch .tensor (volume_size , dtype = xyz .dtype , device = xyz .device )
@@ -128,12 +129,11 @@ def init_volume_boundary_pointcloud(
128
129
return pointclouds , initial_volumes
129
130
130
131
131
- def init_uniform_y_rotations (batch_size : int = 10 ):
132
+ def init_uniform_y_rotations (batch_size : int , device : torch . device ):
132
133
"""
133
134
Generate a batch of `batch_size` 3x3 rotation matrices around y-axis
134
135
whose angles are uniformly distributed between 0 and 2 pi.
135
136
"""
136
- device = torch .device ("cuda:0" )
137
137
axis = torch .tensor ([0.0 , 1.0 , 0.0 ], device = device , dtype = torch .float32 )
138
138
angles = torch .linspace (0 , 2.0 * np .pi , batch_size + 1 , device = device )
139
139
angles = angles [:batch_size ]
@@ -153,17 +153,22 @@ def add_points_to_volumes(
153
153
volume_size : Tuple [int , int , int ],
154
154
n_points : int ,
155
155
interp_mode : str ,
156
+ device : str ,
156
157
):
157
158
(pointclouds , initial_volumes ) = init_volume_boundary_pointcloud (
158
159
batch_size = batch_size ,
159
160
volume_size = volume_size ,
160
161
n_points = n_points ,
161
162
interp_mode = interp_mode ,
162
163
require_grad = False ,
164
+ device = device ,
163
165
)
164
166
167
+ torch .cuda .synchronize ()
168
+
165
169
def _add_points_to_volumes ():
166
170
add_pointclouds_to_volumes (pointclouds , initial_volumes , mode = interp_mode )
171
+ torch .cuda .synchronize ()
167
172
168
173
return _add_points_to_volumes
169
174
@@ -179,12 +184,12 @@ def stack_4d_tensor_to_3d(arr):
179
184
return arr3d
180
185
181
186
@staticmethod
182
- def init_cube_mesh (batch_size : int = 10 ):
187
+ def init_cube_mesh (batch_size : int , device : str ):
183
188
"""
184
189
Generate a batch of `batch_size` cube meshes.
185
190
"""
186
191
187
- device = torch .device ("cuda:0" )
192
+ device = torch .device (device )
188
193
189
194
verts , faces = [], []
190
195
@@ -255,6 +260,7 @@ def test_from_point_cloud(self, interp_mode="trilinear"):
255
260
interp_mode = interp_mode ,
256
261
batch_size = batch_size ,
257
262
require_grad = True ,
263
+ device = "cuda:0" ,
258
264
)
259
265
260
266
volumes = add_pointclouds_to_volumes (
0 commit comments