Skip to content

Commit 01759d8

Browse files
nikhilaravifacebook-github-bot
authored andcommitted
Texture Atlas sampling bug fix
Summary: Fixes the index out of bound errors for texture sampling from a texture atlas: when barycentric coordinates are 1.0, the integer index into the (R, R) per face texture map is R (max can only be R-1). Reviewed By: gkioxari Differential Revision: D25543803 fbshipit-source-id: 82d0935b981352b49c1d95d5a17f9cc88bad0a82
1 parent 3d769a6 commit 01759d8

File tree

3 files changed

+58
-8
lines changed

3 files changed

+58
-8
lines changed

docs/notes/renderer_getting_started.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,8 @@ For mesh texturing we offer several options (in `pytorch3d/renderer/mesh/texturi
8484

8585
1. **Vertex Textures**: D dimensional textures for each vertex (for example an RGB color) which can be interpolated across the face. This can be represented as an `(N, V, D)` tensor. This is a fairly simple representation though and cannot model complex textures if the mesh faces are large.
8686
2. **UV Textures**: vertex UV coordinates and **one** texture map for the whole mesh. For a point on a face with given barycentric coordinates, the face color can be computed by interpolating the vertex uv coordinates and then sampling from the texture map. This representation requires two tensors (UVs: `(N, V, 2), Texture map: `(N, H, W, 3)`), and is limited to only support one texture map per mesh.
87-
3. **Face Textures**: In more complex cases such as ShapeNet meshes, there are multiple texture maps per mesh and some faces have texture while other do not. For these cases, a more flexible representation is a texture atlas, where each face is represented as an `(RxR)` texture map where R is the texture resolution. For a given point on the face, the texture value can be sampled from the per face texture map using the barycentric coordinates of the point. This representation requires one tensor of shape `(N, F, R, R, 3)`. This texturing method is inspired by the SoftRasterizer implementation. For more details refer to the [`make_material_atlas`](https://github.com/facebookresearch/pytorch3d/blob/master/pytorch3d/io/mtl_io.py#L123) and [`sample_textures`](https://github.com/facebookresearch/pytorch3d/blob/master/pytorch3d/renderer/mesh/textures.py#L452) functions.
87+
3. **Face Textures**: In more complex cases such as ShapeNet meshes, there are multiple texture maps per mesh and some faces have texture while other do not. For these cases, a more flexible representation is a texture atlas, where each face is represented as an `(RxR)` texture map where R is the texture resolution. For a given point on the face, the texture value can be sampled from the per face texture map using the barycentric coordinates of the point. This representation requires one tensor of shape `(N, F, R, R, 3)`. This texturing method is inspired by the SoftRasterizer implementation. For more details refer to the [`make_material_atlas`](https://github.com/facebookresearch/pytorch3d/blob/master/pytorch3d/io/mtl_io.py#L123) and [`sample_textures`](https://github.com/facebookresearch/pytorch3d/blob/master/pytorch3d/renderer/mesh/textures.py#L452) functions. **NOTE:**: The `TextureAtlas` texture sampling is only differentiable with respect to the texture atlas but not differentiable with respect to the barycentric coordinates.
88+
8889

8990
<img src="assets/texturing.jpg" width="1000">
9091

pytorch3d/renderer/mesh/textures.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,18 @@ def extend(self, N: int) -> "TexturesAtlas":
479479

480480
def sample_textures(self, fragments, **kwargs) -> torch.Tensor:
481481
"""
482+
This is similar to a nearest neighbor sampling and involves a
483+
discretization step. The barycentric coordinates from
484+
rasterization are used to find the nearest grid cell in the texture
485+
atlas and the RGB is returned as the color.
486+
This means that this step is differentiable with respect to the RGB
487+
values of the texture atlas but not differentiable with respect to the
488+
barycentric coordinates.
489+
490+
TODO: Add a different sampling mode which interpolates the barycentric
491+
coordinates to sample the texture and will be differentiable w.r.t
492+
the barycentric coordinates.
493+
482494
Args:
483495
fragments:
484496
The outputs of rasterization. From this we use
@@ -504,7 +516,10 @@ def sample_textures(self, fragments, **kwargs) -> torch.Tensor:
504516
# pyre-fixme[16]: `bool` has no attribute `__getitem__`.
505517
mask = (pix_to_face < 0)[..., None]
506518
bary_w01 = torch.where(mask, torch.zeros_like(bary_w01), bary_w01)
507-
w_xy = (bary_w01 * R).to(torch.int64) # (N, H, W, K, 2)
519+
# If barycentric coordinates are > 1.0 (in the case of
520+
# blur_radius > 0.0), wxy might be > R. We need to clamp this
521+
# index to R-1 to index into the texture atlas.
522+
w_xy = (bary_w01 * R).to(torch.int64).clamp(max=R - 1) # (N, H, W, K, 2)
508523

509524
below_diag = (
510525
bary_w01.sum(dim=-1) * R - w_xy.float().sum(dim=-1)

tests/test_render_meshes.py

+40-6
Original file line numberDiff line numberDiff line change
@@ -956,6 +956,7 @@ def test_joined_spheres(self):
956956
def test_texture_map_atlas(self):
957957
"""
958958
Test a mesh with a texture map as a per face atlas is loaded and rendered correctly.
959+
Also check that the backward pass for texture atlas rendering is differentiable.
959960
"""
960961
device = torch.device("cuda:0")
961962
obj_dir = Path(__file__).resolve().parent.parent / "docs/tutorials/data"
@@ -970,18 +971,22 @@ def test_texture_map_atlas(self):
970971
texture_atlas_size=8,
971972
texture_wrap=None,
972973
)
974+
atlas = aux.texture_atlas
973975
mesh = Meshes(
974976
verts=[verts],
975977
faces=[faces.verts_idx],
976-
textures=TexturesAtlas(atlas=[aux.texture_atlas]),
978+
textures=TexturesAtlas(atlas=[atlas]),
977979
)
978980

979981
# Init rasterizer settings
980982
R, T = look_at_view_transform(2.7, 0, 0)
981983
cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
982984

983985
raster_settings = RasterizationSettings(
984-
image_size=512, blur_radius=0.0, faces_per_pixel=1, cull_backfaces=True
986+
image_size=512,
987+
blur_radius=0.0,
988+
faces_per_pixel=1,
989+
cull_backfaces=True,
985990
)
986991

987992
# Init shader settings
@@ -993,23 +998,52 @@ def test_texture_map_atlas(self):
993998
lights.location = torch.tensor([0.0, 0.0, 2.0], device=device)[None]
994999

9951000
# The HardPhongShader can be used directly with atlas textures.
1001+
rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
9961002
renderer = MeshRenderer(
997-
rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
1003+
rasterizer=rasterizer,
9981004
shader=HardPhongShader(lights=lights, cameras=cameras, materials=materials),
9991005
)
10001006

10011007
images = renderer(mesh)
1002-
rgb = images[0, ..., :3].squeeze().cpu()
1008+
rgb = images[0, ..., :3].squeeze()
10031009

10041010
# Load reference image
10051011
image_ref = load_rgb_image("test_texture_atlas_8x8_back.png", DATA_DIR)
10061012

10071013
if DEBUG:
1008-
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
1014+
Image.fromarray((rgb.detach().cpu().numpy() * 255).astype(np.uint8)).save(
10091015
DATA_DIR / "DEBUG_texture_atlas_8x8_back.png"
10101016
)
10111017

1012-
self.assertClose(rgb, image_ref, atol=0.05)
1018+
self.assertClose(rgb.cpu(), image_ref, atol=0.05)
1019+
1020+
# Check gradients are propagated
1021+
# correctly back to the texture atlas.
1022+
# Because of how texture sampling is implemented
1023+
# for the texture atlas it is not possible to get
1024+
# gradients back to the vertices.
1025+
atlas.requires_grad = True
1026+
mesh = Meshes(
1027+
verts=[verts],
1028+
faces=[faces.verts_idx],
1029+
textures=TexturesAtlas(atlas=[atlas]),
1030+
)
1031+
raster_settings = RasterizationSettings(
1032+
image_size=512,
1033+
blur_radius=0.0001,
1034+
faces_per_pixel=5,
1035+
cull_backfaces=True,
1036+
clip_barycentric_coords=True,
1037+
)
1038+
images = renderer(mesh, raster_settings=raster_settings)
1039+
images[0, ...].sum().backward()
1040+
1041+
fragments = rasterizer(mesh, raster_settings=raster_settings)
1042+
# Some of the bary coordinates are outisde the
1043+
# [0, 1] range as expected because the blur is > 0
1044+
self.assertTrue(fragments.bary_coords.ge(1.0).any())
1045+
self.assertIsNotNone(atlas.grad)
1046+
self.assertTrue(atlas.grad.sum().abs() > 0.0)
10131047

10141048
def test_simple_sphere_outside_zfar(self):
10151049
"""

0 commit comments

Comments
 (0)