Skip to content

Commit 24f5f4a

Browse files
Darijan Gudeljfacebook-github-bot
Darijan Gudelj
authored andcommitted
VoxelGridModule
Summary: Simple wrapper around voxel grids to make them a module Reviewed By: bottler Differential Revision: D38829762 fbshipit-source-id: dfee85088fa3c65e396cc7d3bf7ebaaffaadb646
1 parent 6653f44 commit 24f5f4a

File tree

2 files changed

+112
-2
lines changed

2 files changed

+112
-2
lines changed

pytorch3d/implicitron/models/implicit_function/voxel_grid.py

+86-2
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,23 @@
88
This file contains classes that implement Voxel grids, both in their full resolution
99
as in the factorized form. There are two factorized forms implemented, Tensor rank decomposition
1010
or CANDECOMP/PARAFAC (here CP) and Vector Matrix (here VM) factorization from the
11-
https://arxiv.org/abs/2203.09517.
11+
TensoRF (https://arxiv.org/abs/2203.09517) paper.
12+
13+
In addition, the module VoxelGridModule implements a trainable instance of one of
14+
these classes.
15+
1216
"""
1317

1418
from dataclasses import dataclass
1519
from typing import ClassVar, Dict, Optional, Tuple, Type
1620

1721
import torch
18-
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
22+
from pytorch3d.implicitron.tools.config import (
23+
Configurable,
24+
registry,
25+
ReplaceableBase,
26+
run_auto_creation,
27+
)
1928
from pytorch3d.structures.volumes import VolumeLocator
2029

2130
from .utils import interpolate_line, interpolate_plane, interpolate_volume
@@ -426,3 +435,78 @@ def get_shapes(self) -> Dict[str, Tuple]:
426435
)
427436

428437
return shape_dict
438+
439+
440+
class VoxelGridModule(Configurable, torch.nn.Module):
441+
"""
442+
A wrapper torch.nn.Module for the VoxelGrid classes, which
443+
contains parameters that are needed to train the VoxelGrid classes.
444+
445+
Members:
446+
voxel_grid_class_type: The name of the class to use for voxel_grid,
447+
which must be available in the registry. Default FullResolutionVoxelGrid.
448+
voxel_grid: An instance of `VoxelGridBase`. This is the object which
449+
this class wraps.
450+
extents: 3-tuple of a form (width, height, depth), denotes the size of the grid
451+
in world units.
452+
translation: 3-tuple of float. The center of the volume in world units as (x, y, z).
453+
init_std: Parameters are initialized using the gaussian distribution
454+
with mean=init_mean and std=init_std. Default 0.1
455+
init_mean: Parameters are initialized using the gaussian distribution
456+
with mean=init_mean and std=init_std. Default 0.
457+
"""
458+
459+
voxel_grid_class_type: str = "FullResolutionVoxelGrid"
460+
voxel_grid: VoxelGridBase
461+
462+
extents: Tuple[float, float, float] = 1.0
463+
translation: Tuple[float, float, float] = (0.0, 0.0, 0.0)
464+
465+
init_std: float = 0.1
466+
init_mean: float = 0
467+
468+
def __post_init__(self):
469+
super().__init__()
470+
run_auto_creation(self)
471+
n_grids = 1 # Voxel grid objects are batched. We need only a single grid.
472+
shapes = self.voxel_grid.get_shapes()
473+
params = {
474+
name: torch.normal(
475+
mean=torch.zeros((n_grids, *shape)) + self.init_mean,
476+
std=self.init_std,
477+
)
478+
for name, shape in shapes.items()
479+
}
480+
self.params = torch.nn.ParameterDict(params)
481+
482+
def forward(self, points: torch.Tensor) -> torch.Tensor:
483+
"""
484+
Evaluates points in the world coordinate frame on the voxel_grid.
485+
486+
Args:
487+
points (torch.Tensor): tensor of points that you want to query
488+
of a form (n_points, 3)
489+
Returns:
490+
torch.Tensor of shape (n_points, n_features)
491+
"""
492+
locator = VolumeLocator(
493+
batch_size=1,
494+
# The resolution of the voxel grid does not need to be known
495+
# to the locator object. It is easiest to fix the resolution of the locator.
496+
# In particular we fix it to (2,2,2) so that there is exactly one voxel of the
497+
# desired size. The locator object uses (z, y, x) convention for the grid_size,
498+
# and this module uses (x, y, z) convention so the order has to be reversed
499+
# (irrelevant in this case since they are all equal).
500+
# It is (2, 2, 2) because the VolumeLocator object behaves like
501+
# align_corners=True, which means that the points are in the corners of
502+
# the volume. So in the grid of (2, 2, 2) there is only one voxel.
503+
grid_sizes=(2, 2, 2),
504+
# The locator object uses (x, y, z) convention for the
505+
# voxel size and translation.
506+
voxel_size=self.extents,
507+
volume_translation=self.translation,
508+
device=next(self.params.values()).device,
509+
)
510+
grid_values = self.voxel_grid.values_type(**self.params)
511+
# voxel grids operate with extra n_grids dimension, which we fix to one
512+
return self.voxel_grid.evaluate_world(points[None], grid_values, locator)[0]

tests/implicitron/test_voxel_grids.py

+26
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
CPFactorizedVoxelGrid,
2020
FullResolutionVoxelGrid,
2121
VMFactorizedVoxelGrid,
22+
VoxelGridModule,
2223
)
2324

2425
from pytorch3d.implicitron.tools.config import expand_args_fields
@@ -198,6 +199,7 @@ def setUp(self):
198199
expand_args_fields(FullResolutionVoxelGrid)
199200
expand_args_fields(CPFactorizedVoxelGrid)
200201
expand_args_fields(VMFactorizedVoxelGrid)
202+
expand_args_fields(VoxelGridModule)
201203

202204
def _interpolate_1D(
203205
self, points: torch.Tensor, vectors: torch.Tensor
@@ -585,3 +587,27 @@ def test(cls, **kwargs):
585587
n_features=10,
586588
n_components=3,
587589
)
590+
591+
def test_voxel_grid_module_location(self, n_times=10):
592+
"""
593+
This checks the module uses locator correctly etc..
594+
595+
If we know that voxel grids work for (x, y, z) in local coordinates
596+
to test if the VoxelGridModule does not have permuted dimensions we
597+
create local coordinates, pass them through verified voxelgrids and
598+
compare the result with the result that we get when we convert
599+
coordinates to world and pass them through the VoxelGridModule
600+
"""
601+
for _ in range(n_times):
602+
extents = tuple(torch.randint(1, 50, size=(3,)).tolist())
603+
604+
grid = VoxelGridModule(extents=extents)
605+
local_point = torch.rand(1, 3) * 2 - 1
606+
world_point = local_point * torch.tensor(extents) / 2
607+
grid_values = grid.voxel_grid.values_type(**grid.params)
608+
609+
assert torch.allclose(
610+
grid(world_point)[0, 0],
611+
grid.voxel_grid.evaluate_local(local_point[None], grid_values)[0, 0, 0],
612+
rtol=0.0001,
613+
)

0 commit comments

Comments
 (0)