Skip to content

Commit 097b0ef

Browse files
bottlerfacebook-github-bot
authored andcommitted
use no_grad for sample_pdf in NeRF project
Summary: We don't use gradents of sample_pdf. Here we disable gradient calculation around calling it, instead of calling detach later. There's a theoretical speedup, but mainly this enables using sample_pdf implementations which don't support gradients. Reviewed By: nikhilaravi Differential Revision: D28057284 fbshipit-source-id: 8a9d5e73f18b34e1e4291028008e02973023638d
1 parent 6053d0e commit 097b0ef

File tree

1 file changed

+4
-7
lines changed

1 file changed

+4
-7
lines changed

projects/nerf/nerf/raysampler.py

+4-7
Original file line numberDiff line numberDiff line change
@@ -69,22 +69,19 @@ def forward(
6969
# Calculate the mid-points between the ray depths.
7070
z_vals = input_ray_bundle.lengths
7171
batch_size = z_vals.shape[0]
72-
z_vals_mid = 0.5 * (z_vals[..., 1:] + z_vals[..., :-1])
7372

7473
# Carry out the importance sampling.
75-
z_samples = (
76-
sample_pdf(
74+
with torch.no_grad():
75+
z_vals_mid = 0.5 * (z_vals[..., 1:] + z_vals[..., :-1])
76+
z_samples = sample_pdf(
7777
z_vals_mid.view(-1, z_vals_mid.shape[-1]),
7878
ray_weights.view(-1, ray_weights.shape[-1])[..., 1:-1],
7979
self._n_pts_per_ray,
8080
det=not (
8181
(self._stratified and self.training)
8282
or (self._stratified_test and not self.training)
8383
),
84-
)
85-
.detach()
86-
.view(batch_size, z_vals.shape[1], self._n_pts_per_ray)
87-
)
84+
).view(batch_size, z_vals.shape[1], self._n_pts_per_ray)
8885

8986
if self._add_input_samples:
9087
# Add the new samples to the input ones.

0 commit comments

Comments
 (0)