@@ -157,12 +157,15 @@ def add_points_features_to_volume_densities_features(
157
157
of its floating point coordinate. The weights are
158
158
determined using a trilinear interpolation scheme.
159
159
Trilinear splatting is fully differentiable w.r.t. all input arguments.
160
- mask: A binary mask of shape `(minibatch, N)` determining which 3D points
161
- are going to be converted to the resulting volume.
162
- Set to `None` if all points are valid.
163
160
min_weight: A scalar controlling the lowest possible total per-voxel
164
161
weight used to normalize the features accumulated in a voxel.
165
162
Only active for `mode==trilinear`.
163
+ mask: A binary mask of shape `(minibatch, N)` determining which 3D points
164
+ are going to be converted to the resulting volume.
165
+ Set to `None` if all points are valid.
166
+ grid_sizes: `LongTensor` of shape (minibatch, 3) representing the
167
+ spatial resolutions of each of the the non-flattened `volumes` tensors,
168
+ or None to indicate the whole volume is used for every batch element.
166
169
Returns:
167
170
volume_features: Output volume of shape `(minibatch, feature_dim, D, H, W)`
168
171
volume_densities: Occupancy volume of shape `(minibatch, 1, D, H, W)`
@@ -284,13 +287,15 @@ def splat_points_to_volumes(
284
287
volume_features: Batch of input *flattened* feature volumes
285
288
of shape `(minibatch, feature_dim, N_voxels)`
286
289
volume_densities: Batch of input *flattened* feature volume densities
287
- of shape `(minibatch, 1, N_voxels )`. Each voxel should
290
+ of shape `(minibatch, N_voxels, 1 )`. Each voxel should
288
291
contain a non-negative number corresponding to its
289
292
opaqueness (the higher, the less transparent).
290
293
grid_sizes: `LongTensor` of shape (minibatch, 3) representing the
291
294
spatial resolutions of each of the the non-flattened `volumes` tensors.
292
295
Note that the following has to hold:
293
296
`torch.prod(grid_sizes, dim=1)==N_voxels`
297
+ min_weight: A scalar controlling the lowest possible total per-voxel
298
+ weight used to normalize the features accumulated in a voxel.
294
299
mask: A binary mask of shape `(minibatch, N)` determining which 3D points
295
300
are going to be converted to the resulting volume.
296
301
Set to `None` if all points are valid.
@@ -457,9 +462,6 @@ def round_points_to_volumes(
457
462
# split into separate coordinate vectors
458
463
X , Y , Z = XYZ .split (1 , dim = 2 )
459
464
460
- # get random indices for the purpose of adding out-of-bounds values
461
- rand_idx = X .new_zeros (X .shape ).random_ (0 , n_voxels )
462
-
463
465
# valid - binary indicators of votes that fall into the volume
464
466
grid_sizes = grid_sizes .type_as (XYZ )
465
467
valid = (
@@ -470,6 +472,8 @@ def round_points_to_volumes(
470
472
* (0 <= Z )
471
473
* (Z < grid_sizes_xyz [:, None , 2 :3 ])
472
474
).long ()
475
+ if mask is not None :
476
+ valid = valid * mask [:, :, None ].long ()
473
477
474
478
# get random indices for the purpose of adding out-of-bounds values
475
479
rand_idx = valid .new_zeros (X .shape ).random_ (0 , n_voxels )
0 commit comments