|
8 | 8 | This file contains classes that implement Voxel grids, both in their full resolution
|
9 | 9 | and in the factorized form. There are two factorized forms implemented, Tensor rank decomposition
|
10 | 10 | or CANDECOMP/PARAFAC (here CP) and Vector Matrix (here VM) factorization from the
|
11 |
| -https://arxiv.org/abs/2203.09517. |
| 11 | +TensoRF (https://arxiv.org/abs/2203.09517) paper. |
| 12 | +
|
| 13 | +In addition, the module VoxelGridModule implements a trainable instance of one of |
| 14 | +these classes. |
| 15 | +
|
12 | 16 | """
|
13 | 17 |
|
14 | 18 | from dataclasses import dataclass
|
15 | 19 | from typing import ClassVar, Dict, Optional, Tuple, Type
|
16 | 20 |
|
17 | 21 | import torch
|
18 |
| -from pytorch3d.implicitron.tools.config import registry, ReplaceableBase |
| 22 | +from pytorch3d.implicitron.tools.config import ( |
| 23 | + Configurable, |
| 24 | + registry, |
| 25 | + ReplaceableBase, |
| 26 | + run_auto_creation, |
| 27 | +) |
19 | 28 | from pytorch3d.structures.volumes import VolumeLocator
|
20 | 29 |
|
21 | 30 | from .utils import interpolate_line, interpolate_plane, interpolate_volume
|
@@ -426,3 +435,78 @@ def get_shapes(self) -> Dict[str, Tuple]:
|
426 | 435 | )
|
427 | 436 |
|
428 | 437 | return shape_dict
|
| 438 | + |
| 439 | + |
class VoxelGridModule(Configurable, torch.nn.Module):
    """
    A wrapper torch.nn.Module for the VoxelGrid classes, which
    contains parameters that are needed to train the VoxelGrid classes.

    Members:
        voxel_grid_class_type: The name of the class to use for voxel_grid,
            which must be available in the registry. Default FullResolutionVoxelGrid.
        voxel_grid: An instance of `VoxelGridBase`. This is the object which
            this class wraps.
        extents: 3-tuple of a form (width, height, depth), denotes the size of the grid
            in world units. Default (1.0, 1.0, 1.0).
        translation: 3-tuple of float. The center of the volume in world units as (x, y, z).
            Default (0.0, 0.0, 0.0).
        init_std: Parameters are initialized using the gaussian distribution
            with mean=init_mean and std=init_std. Default 0.1
        init_mean: Parameters are initialized using the gaussian distribution
            with mean=init_mean and std=init_std. Default 0.
    """

    voxel_grid_class_type: str = "FullResolutionVoxelGrid"
    voxel_grid: VoxelGridBase

    # The default must be a 3-tuple so that it agrees with the annotation; a bare
    # scalar default does not match `Tuple[float, float, float]`.
    extents: Tuple[float, float, float] = (1.0, 1.0, 1.0)
    translation: Tuple[float, float, float] = (0.0, 0.0, 0.0)

    init_std: float = 0.1
    init_mean: float = 0.0

    def __post_init__(self):
        """
        Creates the wrapped voxel grid and one trainable parameter tensor for
        every entry returned by `voxel_grid.get_shapes()`.
        """
        super().__init__()
        run_auto_creation(self)
        n_grids = 1  # Voxel grid objects are batched. We need only a single grid.
        shapes = self.voxel_grid.get_shapes()
        params = {
            name: torch.normal(
                mean=torch.zeros((n_grids, *shape)) + self.init_mean,
                std=self.init_std,
            )
            for name, shape in shapes.items()
        }
        self.params = torch.nn.ParameterDict(params)

    def forward(self, points: torch.Tensor) -> torch.Tensor:
        """
        Evaluates points in the world coordinate frame on the voxel_grid.

        Args:
            points (torch.Tensor): tensor of points that you want to query
                of a form (n_points, 3)
        Returns:
            torch.Tensor of shape (n_points, n_features)
        """
        # `values()` may return either a view or an iterator depending on the
        # torch version; `iter()` makes `next()` valid in both cases.
        device = next(iter(self.params.values())).device
        locator = VolumeLocator(
            batch_size=1,
            # The resolution of the voxel grid does not need to be known
            # to the locator object. It is easiest to fix the resolution of the locator.
            # In particular we fix it to (2,2,2) so that there is exactly one voxel of the
            # desired size. The locator object uses (z, y, x) convention for the grid_size,
            # and this module uses (x, y, z) convention so the order has to be reversed
            # (irrelevant in this case since they are all equal).
            # It is (2, 2, 2) because the VolumeLocator object behaves like
            # align_corners=True, which means that the points are in the corners of
            # the volume. So in the grid of (2, 2, 2) there is only one voxel.
            grid_sizes=(2, 2, 2),
            # The locator object uses (x, y, z) convention for the
            # voxel size and translation.
            voxel_size=self.extents,
            volume_translation=self.translation,
            device=device,
        )
        grid_values = self.voxel_grid.values_type(**self.params)
        # voxel grids operate with extra n_grids dimension, which we fix to one
        return self.voxel_grid.evaluate_world(points[None], grid_values, locator)[0]
0 commit comments