@@ -261,19 +261,28 @@ class VolumeSampler(torch.nn.Module):
261
261
at 3D points sampled along projection rays.
262
262
"""
263
263
264
- def __init__ (self , volumes : Volumes , sample_mode : str = "bilinear" ) -> None :
264
+ def __init__ (
265
+ self ,
266
+ volumes : Volumes ,
267
+ sample_mode : str = "bilinear" ,
268
+ padding_mode : str = "zeros" ,
269
+ ) -> None :
265
270
"""
266
271
Args:
267
272
volumes: An instance of the `Volumes` class representing a
268
273
batch of volumes that are being rendered.
269
274
sample_mode: Defines the algorithm used to sample the volumetric
270
275
voxel grid. Can be either "bilinear" or "nearest".
276
+ padding_mode: How to handle values outside of the volume.
277
+ One of: zeros, border, reflection
278
+ See torch.nn.functional.grid_sample for more information.
271
279
"""
272
280
super ().__init__ ()
273
281
if not isinstance (volumes , Volumes ):
274
282
raise ValueError ("'volumes' have to be an instance of the 'Volumes' class." )
275
283
self ._volumes = volumes
276
284
self ._sample_mode = sample_mode
285
+ self ._padding_mode = padding_mode
277
286
278
287
def _get_ray_directions_transform (self ):
279
288
"""
@@ -375,6 +384,7 @@ def forward(
375
384
rays_points_local_flat ,
376
385
align_corners = True ,
377
386
mode = self ._sample_mode ,
387
+ padding_mode = self ._padding_mode ,
378
388
)
379
389
380
390
# permute the dimensions & reshape densities after sampling
@@ -392,6 +402,7 @@ def forward(
392
402
rays_points_local_flat ,
393
403
align_corners = True ,
394
404
mode = self ._sample_mode ,
405
+ padding_mode = self ._padding_mode ,
395
406
)
396
407
397
408
# permute the dimensions & reshape features after sampling
0 commit comments