Skip to content

Commit 305cf32

Browse files
bottlerfacebook-github-bot
authored andcommitted
Avoid raysampler dict
Summary: A significant speedup (e.g. >2% of a forward pass). Move NDCMultinomialRaysampler parts of AbstractMaskRaySampler to members instead of living in a dict. The dict was hiding them from the nn.Module system so their _xy_grid members were remaining on the CPU. Therefore they were being copied to the GPU in every forward pass. (We couldn't easily use a ModuleDict here because the enum keys are not strs.) Reviewed By: shapovalov Differential Revision: D39668589 fbshipit-source-id: 719b88e4a08fd7263a284e0ab38189e666bd7e3a
1 parent da7fe28 commit 305cf32

File tree

1 file changed

+32
-32
lines changed

1 file changed

+32
-32
lines changed

pytorch3d/implicitron/models/renderer/ray_sampler.py

+32-32
Original file line numberDiff line numberDiff line change
@@ -100,34 +100,32 @@ def __post_init__(self):
100100
),
101101
}
102102

103-
self._raysamplers = {
104-
EvaluationMode.TRAINING: NDCMultinomialRaysampler(
105-
image_width=self.image_width,
106-
image_height=self.image_height,
107-
n_pts_per_ray=self.n_pts_per_ray_training,
108-
min_depth=0.0,
109-
max_depth=0.0,
110-
n_rays_per_image=self.n_rays_per_image_sampled_from_mask
111-
if self._sampling_mode[EvaluationMode.TRAINING]
112-
== RenderSamplingMode.MASK_SAMPLE
113-
else None,
114-
unit_directions=True,
115-
stratified_sampling=self.stratified_point_sampling_training,
116-
),
117-
EvaluationMode.EVALUATION: NDCMultinomialRaysampler(
118-
image_width=self.image_width,
119-
image_height=self.image_height,
120-
n_pts_per_ray=self.n_pts_per_ray_evaluation,
121-
min_depth=0.0,
122-
max_depth=0.0,
123-
n_rays_per_image=self.n_rays_per_image_sampled_from_mask
124-
if self._sampling_mode[EvaluationMode.EVALUATION]
125-
== RenderSamplingMode.MASK_SAMPLE
126-
else None,
127-
unit_directions=True,
128-
stratified_sampling=self.stratified_point_sampling_evaluation,
129-
),
130-
}
103+
self._training_raysampler = NDCMultinomialRaysampler(
104+
image_width=self.image_width,
105+
image_height=self.image_height,
106+
n_pts_per_ray=self.n_pts_per_ray_training,
107+
min_depth=0.0,
108+
max_depth=0.0,
109+
n_rays_per_image=self.n_rays_per_image_sampled_from_mask
110+
if self._sampling_mode[EvaluationMode.TRAINING]
111+
== RenderSamplingMode.MASK_SAMPLE
112+
else None,
113+
unit_directions=True,
114+
stratified_sampling=self.stratified_point_sampling_training,
115+
)
116+
self._evaluation_raysampler = NDCMultinomialRaysampler(
117+
image_width=self.image_width,
118+
image_height=self.image_height,
119+
n_pts_per_ray=self.n_pts_per_ray_evaluation,
120+
min_depth=0.0,
121+
max_depth=0.0,
122+
n_rays_per_image=self.n_rays_per_image_sampled_from_mask
123+
if self._sampling_mode[EvaluationMode.EVALUATION]
124+
== RenderSamplingMode.MASK_SAMPLE
125+
else None,
126+
unit_directions=True,
127+
stratified_sampling=self.stratified_point_sampling_evaluation,
128+
)
131129

132130
def _get_min_max_depth_bounds(self, cameras: CamerasBase) -> Tuple[float, float]:
133131
raise NotImplementedError()
@@ -169,11 +167,13 @@ def forward(
169167

170168
min_depth, max_depth = self._get_min_max_depth_bounds(cameras)
171169

170+
raysampler = {
171+
EvaluationMode.TRAINING: self._training_raysampler,
172+
EvaluationMode.EVALUATION: self._evaluation_raysampler,
173+
}[evaluation_mode]
174+
172175
# pyre-fixme[29]:
173-
# `Union[BoundMethod[typing.Callable(torch.Tensor.__getitem__)[[Named(self,
174-
# torch.Tensor), Named(item, typing.Any)], typing.Any], torch.Tensor],
175-
# torch.Tensor, torch.nn.Module]` is not a function.
176-
ray_bundle = self._raysamplers[evaluation_mode](
176+
ray_bundle = raysampler(
177177
cameras=cameras,
178178
mask=sample_mask,
179179
min_depth=min_depth,

0 commit comments

Comments
 (0)