Skip to content

Commit a15c33a

Browse files
nikhilaravifacebook-github-bot
authored andcommitted
Alpha channel to return the mask
Summary: Updated the alpha channel in the `hard_rgb_blend` function to return the mask of the pixels which have overlapping mesh faces. Reviewed By: bottler Differential Revision: D29001604 fbshipit-source-id: 22a2173d769f2d3ad34892d68ceb628f073bca22
1 parent ac6c07f commit a15c33a

File tree

3 files changed

+12
-3
lines changed

3 files changed

+12
-3
lines changed

pytorch3d/renderer/blending.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ def hard_rgb_blend(colors, fragments, blend_params) -> torch.Tensor:
5858
) # (N, H, W, 3)
5959

6060
# Concat with the alpha channel.
61-
alpha = torch.ones((N, H, W, 1), dtype=colors.dtype, device=device)
61+
alpha = (~is_background).type_as(pixel_colors)[..., None]
62+
6263
return torch.cat([pixel_colors, alpha], dim=-1) # (N, H, W, 4)
6364

6465

tests/test_blending.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -184,8 +184,8 @@ def test_hard_rgb_blend(self):
184184
channel_color = blend_params.background_color[i]
185185
self.assertTrue(images[~is_foreground][..., i].eq(channel_color).all())
186186

187-
# Examine the alpha channel is correct
188-
self.assertTrue(images[..., 3].eq(1).all())
187+
# Examine the alpha channel
188+
self.assertClose(images[..., 3], (pix_to_face[..., 0] >= 0).float())
189189

190190
def test_sigmoid_alpha_blend_manual_gradients(self):
191191
# Create dummy outputs of rasterization

tests/test_render_meshes.py

+8
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,10 @@ def test_simple_sphere(self, elevated_camera=False, check_depth=False):
125125
)
126126
images, fragments = renderer(sphere_mesh)
127127
self.assertClose(fragments.zbuf, rasterizer(sphere_mesh).zbuf)
128+
# Check the alpha channel is the mask
129+
self.assertClose(
130+
images[..., -1], (fragments.pix_to_face[..., 0] >= 0).float()
131+
)
128132
else:
129133
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
130134
images = renderer(sphere_mesh)
@@ -165,6 +169,10 @@ def test_simple_sphere(self, elevated_camera=False, check_depth=False):
165169
self.assertClose(
166170
fragments.zbuf, rasterizer(sphere_mesh, lights=lights).zbuf
167171
)
172+
# Check the alpha channel is the mask
173+
self.assertClose(
174+
images[..., -1], (fragments.pix_to_face[..., 0] >= 0).float()
175+
)
168176
else:
169177
phong_renderer = MeshRenderer(
170178
rasterizer=rasterizer, shader=phong_shader

0 commit comments

Comments
 (0)