
Commit bc8361f

nikhilaravi authored and facebook-github-bot committed
Lighting broadcasting bug fix
Summary: Fixed multiple issues with shape broadcasting in lighting, shading and blending, and updated the tests.

Reviewed By: bottler

Differential Revision: D28997941

fbshipit-source-id: d3ef93f979344076b1d9098a86178b4da63844c8
1 parent 9de627e commit bc8361f

File tree

4 files changed: +73 −31 lines changed


pytorch3d/renderer/blending.py

+13 −5
@@ -1,12 +1,11 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 
 
-from typing import NamedTuple, Sequence
+from typing import NamedTuple, Sequence, Union
 
 import torch
 from pytorch3d import _C  # pyre-fixme[21]: Could not find name `_C` in `pytorch3d`.
 
-
 # Example functions for blending the top K colors per pixel using the outputs
 # from rasterization.
 # NOTE: All blending function should return an RGBA image per batch element
@@ -117,7 +116,11 @@ def sigmoid_alpha_blend(colors, fragments, blend_params) -> torch.Tensor:
 
 
 def softmax_rgb_blend(
-    colors, fragments, blend_params, znear: float = 1.0, zfar: float = 100
+    colors,
+    fragments,
+    blend_params,
+    znear: Union[float, torch.Tensor] = 1.0,
+    zfar: Union[float, torch.Tensor] = 100,
 ) -> torch.Tensor:
     """
     RGB and alpha channel blending to return an RGBA image based on the method
@@ -184,11 +187,16 @@ def softmax_rgb_blend(
     # overflow. zbuf shape (N, H, W, K), find max over K.
     # TODO: there may still be some instability in the exponent calculation.
 
+    # Reshape to be compatible with (N, H, W, K) values in fragments
+    if torch.is_tensor(zfar):
+        # pyre-fixme[16]
+        zfar = zfar[:, None, None, None]
+    if torch.is_tensor(znear):
+        znear = znear[:, None, None, None]
+
     z_inv = (zfar - fragments.zbuf) / (zfar - znear) * mask
     # pyre-fixme[16]: `Tuple` has no attribute `values`.
-    # pyre-fixme[6]: Expected `Tensor` for 1st param but got `float`.
     z_inv_max = torch.max(z_inv, dim=-1).values[..., None].clamp(min=eps)
-    # pyre-fixme[6]: Expected `Tensor` for 1st param but got `float`.
     weights_num = prob_map * torch.exp((z_inv - z_inv_max) / blend_params.gamma)
 
     # Also apply exp normalize trick for the background color weight.
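As an aside (not part of the diff): a small, self-contained sketch of the shape handling this change enables in softmax_rgb_blend. With per-batch clipping planes, znear and zfar arrive as (N,) tensors and need singleton dimensions before they can broadcast against the (N, H, W, K) zbuf; all tensors below are illustrative stand-ins.

import torch

N, H, W, K = 2, 4, 4, 3
zbuf = torch.rand(N, H, W, K) * 99 + 1    # stand-in for fragments.zbuf
mask = torch.ones(N, H, W, K)             # stand-in for the foreground mask
znear = torch.tensor([1.0, 0.5])          # per-batch near planes, shape (N,)
zfar = torch.tensor([100.0, 80.0])        # per-batch far planes, shape (N,)

# Same reshape as the new code above: (N,) -> (N, 1, 1, 1)
if torch.is_tensor(zfar):
    zfar = zfar[:, None, None, None]
if torch.is_tensor(znear):
    znear = znear[:, None, None, None]

z_inv = (zfar - zbuf) / (zfar - znear) * mask  # broadcasts to (N, H, W, K)
print(z_inv.shape)  # torch.Size([2, 4, 4, 3])

Float znear/zfar values skip the reshape and broadcast as plain scalars, so the existing call sites are unaffected.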

pytorch3d/renderer/lighting.py

+16 −2
@@ -253,12 +253,26 @@ def clone(self):
         other = self.__class__(device=self.device)
         return super().clone(other)
 
+    def reshape_location(self, points) -> torch.Tensor:
+        """
+        Reshape the location tensor to have dimensions
+        compatible with the points which can either be of
+        shape (P, 3) or (N, H, W, K, 3).
+        """
+        if self.location.ndim == points.ndim:
+            # pyre-fixme[7]
+            return self.location
+        # pyre-fixme[29]
+        return self.location[:, None, None, None, :]
+
     def diffuse(self, normals, points) -> torch.Tensor:
-        direction = self.location - points
+        location = self.reshape_location(points)
+        direction = location - points
         return diffuse(normals=normals, color=self.diffuse_color, direction=direction)
 
     def specular(self, normals, points, camera_position, shininess) -> torch.Tensor:
-        direction = self.location - points
+        location = self.reshape_location(points)
+        direction = location - points
         return specular(
             points=points,
             normals=normals,

pytorch3d/renderer/mesh/shading.py

+8 −2
@@ -14,8 +14,8 @@ def _apply_lighting(
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Args:
-        points: torch tensor of shape (N, P, 3) or (P, 3).
-        normals: torch tensor of shape (N, P, 3) or (P, 3)
+        points: torch tensor of shape (N, ..., 3) or (P, 3).
+        normals: torch tensor of shape (N, ..., 3) or (P, 3)
         lights: instance of the Lights class.
         cameras: instance of the Cameras class.
         materials: instance of the Materials class.
@@ -35,13 +35,19 @@ def _apply_lighting(
     ambient_color = materials.ambient_color * lights.ambient_color
     diffuse_color = materials.diffuse_color * light_diffuse
     specular_color = materials.specular_color * light_specular
+
     if normals.dim() == 2 and points.dim() == 2:
         # If given packed inputs remove batch dim in output.
         return (
             ambient_color.squeeze(),
             diffuse_color.squeeze(),
             specular_color.squeeze(),
         )
+
+    if ambient_color.ndim != diffuse_color.ndim:
+        # Reshape from (N, 3) to have dimensions compatible with
+        # diffuse_color which is of shape (N, H, W, K, 3)
+        ambient_color = ambient_color[:, None, None, None, :]
     return ambient_color, diffuse_color, specular_color
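As a side note (not part of the diff): the guard above exists because the ambient term is per batch element, (N, 3), while the diffuse and specular terms are per pixel, (N, H, W, K, 3). A small illustrative sketch of the same reshape with stand-in tensors:

import torch

N, H, W, K = 2, 8, 8, 4
ambient_color = torch.rand(N, 3)           # per-batch ambient term
diffuse_color = torch.rand(N, H, W, K, 3)  # per-pixel diffuse term

# Same guard as in _apply_lighting: lift (N, 3) to (N, 1, 1, 1, 3)
if ambient_color.ndim != diffuse_color.ndim:
    ambient_color = ambient_color[:, None, None, None, :]

shaded = ambient_color + diffuse_color  # broadcasts to (N, H, W, K, 3)
print(shaded.shape)  # torch.Size([2, 8, 8, 4, 3])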

tests/test_render_meshes.py

+36 −22
@@ -6,6 +6,7 @@
 """
 import os
 import unittest
+from collections import namedtuple
 
 import numpy as np
 import torch
@@ -53,6 +54,8 @@
 DATA_DIR = get_tests_dir() / "data"
 TUTORIAL_DATA_DIR = get_pytorch3d_dir() / "docs/tutorials/data"
 
+ShaderTest = namedtuple("ShaderTest", ["shader", "reference_name", "debug_name"])
+
 
 class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
     def test_simple_sphere(self, elevated_camera=False, check_depth=False):
@@ -107,13 +110,13 @@ def test_simple_sphere(self, elevated_camera=False, check_depth=False):
             blend_params = BlendParams(1e-4, 1e-4, (0, 0, 0))
 
             # Test several shaders
-            shaders = {
-                "phong": HardPhongShader,
-                "gouraud": HardGouraudShader,
-                "flat": HardFlatShader,
-            }
-            for (name, shader_init) in shaders.items():
-                shader = shader_init(
+            shader_tests = [
+                ShaderTest(HardPhongShader, "phong", "hard_phong"),
+                ShaderTest(HardGouraudShader, "gouraud", "hard_gouraud"),
+                ShaderTest(HardFlatShader, "flat", "hard_flat"),
+            ]
+            for test in shader_tests:
+                shader = test.shader(
                     lights=lights,
                     cameras=cameras,
                     materials=materials,
@@ -135,7 +138,7 @@ def test_simple_sphere(self, elevated_camera=False, check_depth=False):
 
                 rgb = images[0, ..., :3].squeeze().cpu()
                 filename = "simple_sphere_light_%s%s%s.png" % (
-                    name,
+                    test.reference_name,
                     postfix,
                     cam_type.__name__,
                 )
@@ -144,7 +147,12 @@ def test_simple_sphere(self, elevated_camera=False, check_depth=False):
                 self.assertClose(rgb, image_ref, atol=0.05)
 
                 if DEBUG:
-                    filename = "DEBUG_%s" % filename
+                    debug_filename = "simple_sphere_light_%s%s%s.png" % (
+                        test.debug_name,
+                        postfix,
+                        cam_type.__name__,
+                    )
+                    filename = "DEBUG_%s" % debug_filename
                     Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
                         DATA_DIR / filename
                     )
@@ -269,7 +277,8 @@ def test_simple_sphere_screen(self):
     def test_simple_sphere_batched(self):
         """
         Test a mesh with vertex textures can be extended to form a batch, and
-        is rendered correctly with Phong, Gouraud and Flat Shaders.
+        is rendered correctly with Phong, Gouraud and Flat Shaders with batched
+        lighting and hard and soft blending.
         """
         batch_size = 5
         device = torch.device("cuda:0")
@@ -291,24 +300,28 @@ def test_simple_sphere_batched(self):
         R, T = look_at_view_transform(dist, elev, azim)
         cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
         raster_settings = RasterizationSettings(
-            image_size=512, blur_radius=0.0, faces_per_pixel=1
+            image_size=512, blur_radius=0.0, faces_per_pixel=4
         )
 
         # Init shader settings
         materials = Materials(device=device)
-        lights = PointLights(device=device)
-        lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None]
+        lights_location = torch.tensor([0.0, 0.0, +2.0], device=device)
+        lights_location = lights_location[None].expand(batch_size, -1)
+        lights = PointLights(device=device, location=lights_location)
         blend_params = BlendParams(1e-4, 1e-4, (0, 0, 0))
 
         # Init renderer
         rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
-        shaders = {
-            "phong": HardPhongShader,
-            "gouraud": HardGouraudShader,
-            "flat": HardFlatShader,
-        }
-        for (name, shader_init) in shaders.items():
-            shader = shader_init(
+        shader_tests = [
+            ShaderTest(HardPhongShader, "phong", "hard_phong"),
+            ShaderTest(SoftPhongShader, "phong", "soft_phong"),
+            ShaderTest(HardGouraudShader, "gouraud", "hard_gouraud"),
+            ShaderTest(HardFlatShader, "flat", "hard_flat"),
+        ]
+        for test in shader_tests:
+            reference_name = test.reference_name
+            debug_name = test.debug_name
+            shader = test.shader(
                 lights=lights,
                 cameras=cameras,
                 materials=materials,
@@ -317,14 +330,15 @@ def test_simple_sphere_batched(self):
            renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
            images = renderer(sphere_meshes)
            image_ref = load_rgb_image(
-                "test_simple_sphere_light_%s_%s.png" % (name, type(cameras).__name__),
+                "test_simple_sphere_light_%s_%s.png"
+                % (reference_name, type(cameras).__name__),
                DATA_DIR,
            )
            for i in range(batch_size):
                rgb = images[i, ..., :3].squeeze().cpu()
                if i == 0 and DEBUG:
                    filename = "DEBUG_simple_sphere_batched_%s_%s.png" % (
-                        name,
+                        debug_name,
                        type(cameras).__name__,
                    )
                    Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
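For orientation (not from the commit, and only a sketch): the scenario the updated batched test exercises is one point-light location per batch element rendered with soft blending and K > 1 faces per pixel. The mesh construction below via ico_sphere with constant vertex colors is an illustrative assumption, not taken from this diff.

import torch
from pytorch3d.renderer import (
    FoVPerspectiveCameras,
    MeshRasterizer,
    MeshRenderer,
    PointLights,
    RasterizationSettings,
    SoftPhongShader,
    TexturesVertex,
    look_at_view_transform,
)
from pytorch3d.structures import Meshes
from pytorch3d.utils import ico_sphere

device = torch.device("cuda:0")
batch_size = 4

# Sphere mesh with constant white vertex colors, extended to a batch.
sphere = ico_sphere(2, device)
verts, faces = sphere.verts_padded(), sphere.faces_padded()
mesh = Meshes(
    verts=verts,
    faces=faces,
    textures=TexturesVertex(verts_features=torch.ones_like(verts)),
)
meshes = mesh.extend(batch_size)

# One camera view and one point-light location per batch element.
R, T = look_at_view_transform(
    dist=2.7, elev=0.0, azim=torch.linspace(-90, 90, batch_size)
)
cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
light_location = torch.tensor([[0.0, 0.0, 2.0]], device=device).expand(batch_size, -1)
lights = PointLights(device=device, location=light_location)

# K = 4 faces per pixel so the shader sees (N, H, W, K, ...) buffers.
raster_settings = RasterizationSettings(
    image_size=256, blur_radius=0.0, faces_per_pixel=4
)
renderer = MeshRenderer(
    rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
    shader=SoftPhongShader(device=device, cameras=cameras, lights=lights),
)
images = renderer(meshes)  # (N, H, W, 4) RGBA, one image per batch element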
