Skip to content

Commit ef16253

Browse files
nikhilaravifacebook-github-bot
authored andcommitted
textures dimension check
Summary: When textures are set on `Meshes` we need to check if the dimensions actually match that of the verts/faces in the mesh. There was a github issue where someone tried to set the attribute after construction of the `Meshes` object and ran into an error when trying to sample textures. The desired usage is to initialize the class with the textures (not set an attribute afterwards) but in any case we need to check the dimensions match before sampling textures. Reviewed By: bottler Differential Revision: D29020249 fbshipit-source-id: 9fb8a5368b83c9ec53652df92b96fc8b2613f591
1 parent 1cd1436 commit ef16253

File tree

3 files changed

+148
-7
lines changed

3 files changed

+148
-7
lines changed

pytorch3d/renderer/mesh/textures.py

+37
Original file line numberDiff line numberDiff line change
@@ -574,6 +574,15 @@ def join_scene(self) -> "TexturesAtlas":
574574
"""
575575
return self.__class__(atlas=[torch.cat(self.atlas_list())])
576576

577+
def check_shapes(
578+
self, batch_size: int, max_num_verts: int, max_num_faces: int
579+
) -> bool:
580+
"""
581+
Check if the dimensions of the atlas match that of the mesh faces
582+
"""
583+
# (N, F) should be the same
584+
return self.atlas_padded().shape[0:2] == (batch_size, max_num_faces)
585+
577586

578587
class TexturesUV(TexturesBase):
579588
def __init__(
@@ -1213,6 +1222,18 @@ def centers_for_image(self, index: int) -> torch.Tensor:
12131222
centers = centers[0, :, 0].T
12141223
return centers
12151224

1225+
def check_shapes(
1226+
self, batch_size: int, max_num_verts: int, max_num_faces: int
1227+
) -> bool:
1228+
"""
1229+
Check if the dimensions of the verts/faces uvs match that of the mesh
1230+
"""
1231+
# (N, F) should be the same
1232+
# (N, V) is not guaranteed to be the same
1233+
return (self.faces_uvs_padded().shape[0:2] == (batch_size, max_num_faces)) and (
1234+
self.verts_uvs_padded().shape[0] == batch_size
1235+
)
1236+
12161237

12171238
class TexturesVertex(TexturesBase):
12181239
def __init__(
@@ -1292,6 +1313,13 @@ def __getitem__(self, index) -> "TexturesVertex":
12921313
new_props = self._getitem(index, props)
12931314
verts_features = new_props["verts_features_list"]
12941315
if isinstance(verts_features, list):
1316+
# Handle the case of an empty list
1317+
if len(verts_features) == 0:
1318+
verts_features = torch.empty(
1319+
size=(0, 0, 3),
1320+
dtype=torch.float32,
1321+
device=self.verts_features_padded().device,
1322+
)
12951323
new_tex = self.__class__(verts_features=verts_features)
12961324
elif torch.is_tensor(verts_features):
12971325
new_tex = self.__class__(verts_features=[verts_features])
@@ -1410,3 +1438,12 @@ def join_scene(self) -> "TexturesVertex":
14101438
Return a new TexturesVertex amalgamating the batch.
14111439
"""
14121440
return self.__class__(verts_features=[torch.cat(self.verts_features_list())])
1441+
1442+
def check_shapes(
1443+
self, batch_size: int, max_num_verts: int, max_num_faces: int
1444+
) -> bool:
1445+
"""
1446+
Check if the dimensions of the verts features match that of the mesh verts
1447+
"""
1448+
# (N, V) should be the same
1449+
return self.verts_features_padded().shape[:-1] == (batch_size, max_num_verts)

pytorch3d/structures/meshes.py

+14-2
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,7 @@ def __init__(
255255
if textures is not None and not hasattr(textures, "sample_textures"):
256256
msg = "Expected textures to be an instance of type TexturesBase; got %r"
257257
raise ValueError(msg % type(textures))
258+
258259
self.textures = textures
259260

260261
# Indicates whether the meshes in the list/batch have the same number
@@ -424,10 +425,14 @@ def __init__(
424425
)
425426

426427
# Set the num verts/faces on the textures if present.
427-
if self.textures is not None:
428+
if textures is not None:
429+
shape_ok = self.textures.check_shapes(self._N, self._V, self._F)
430+
if not shape_ok:
431+
msg = "Textures do not match the dimensions of Meshes."
432+
raise ValueError(msg)
433+
428434
self.textures._num_faces_per_mesh = self._num_faces_per_mesh.tolist()
429435
self.textures._num_verts_per_mesh = self._num_verts_per_mesh.tolist()
430-
self.textures._N = self._N
431436
self.textures.valid = self.valid
432437

433438
if verts_normals is not None:
@@ -1560,6 +1565,13 @@ def extend(self, N: int):
15601565

15611566
def sample_textures(self, fragments):
15621567
if self.textures is not None:
1568+
1569+
# Check dimensions of textures match that of meshes
1570+
shape_ok = self.textures.check_shapes(self._N, self._V, self._F)
1571+
if not shape_ok:
1572+
msg = "Textures do not match the dimensions of Meshes."
1573+
raise ValueError(msg)
1574+
15631575
# Pass in faces packed. If the textures are defined per
15641576
# vertex, the face indices are needed in order to interpolate
15651577
# the vertex attributes across the face.

tests/test_texturing.py

+97-5
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ def test_padded_to_packed(self):
251251
def test_getitem(self):
252252
N = 5
253253
V = 20
254-
source = {"verts_features": torch.randn(size=(N, 10, 128))}
254+
source = {"verts_features": torch.randn(size=(N, V, 128))}
255255
tex = TexturesVertex(verts_features=source["verts_features"])
256256

257257
verts = torch.rand(size=(N, V, 3))
@@ -268,6 +268,30 @@ def test_getitem(self):
268268
tryindex(self, index, tex, meshes, source)
269269
tryindex(self, [2, 4], tex, meshes, source)
270270

271+
def test_sample_textures_error(self):
272+
N = 5
273+
V = 20
274+
verts = torch.rand(size=(N, V, 3))
275+
faces = torch.randint(size=(N, 10, 3), high=V)
276+
tex = TexturesVertex(verts_features=torch.randn(size=(N, 10, 128)))
277+
278+
# Verts features have the wrong number of verts
279+
with self.assertRaisesRegex(ValueError, "do not match the dimensions"):
280+
Meshes(verts=verts, faces=faces, textures=tex)
281+
282+
# Verts features have the wrong batch dim
283+
tex = TexturesVertex(verts_features=torch.randn(size=(1, V, 128)))
284+
with self.assertRaisesRegex(ValueError, "do not match the dimensions"):
285+
Meshes(verts=verts, faces=faces, textures=tex)
286+
287+
meshes = Meshes(verts=verts, faces=faces)
288+
meshes.textures = tex
289+
290+
# Cannot use the texture attribute set on meshes for sampling
291+
# textures if the dimensions don't match
292+
with self.assertRaisesRegex(ValueError, "do not match the dimensions"):
293+
meshes.sample_textures(None)
294+
271295

272296
class TestTexturesAtlas(TestCaseMixin, unittest.TestCase):
273297
def test_sample_texture_atlas(self):
@@ -456,11 +480,12 @@ def test_padded_to_packed(self):
456480
def test_getitem(self):
457481
N = 5
458482
V = 20
459-
source = {"atlas": torch.randn(size=(N, 10, 4, 4, 3))}
483+
F = 10
484+
source = {"atlas": torch.randn(size=(N, F, 4, 4, 3))}
460485
tex = TexturesAtlas(atlas=source["atlas"])
461486

462487
verts = torch.rand(size=(N, V, 3))
463-
faces = torch.randint(size=(N, 10, 3), high=V)
488+
faces = torch.randint(size=(N, F, 3), high=V)
464489
meshes = Meshes(verts=verts, faces=faces, textures=tex)
465490

466491
tryindex(self, 2, tex, meshes, source)
@@ -473,6 +498,32 @@ def test_getitem(self):
473498
tryindex(self, index, tex, meshes, source)
474499
tryindex(self, [2, 4], tex, meshes, source)
475500

501+
def test_sample_textures_error(self):
502+
N = 1
503+
V = 20
504+
F = 10
505+
verts = torch.rand(size=(5, V, 3))
506+
faces = torch.randint(size=(5, F, 3), high=V)
507+
meshes = Meshes(verts=verts, faces=faces)
508+
509+
# TexturesAtlas have the wrong batch dim
510+
tex = TexturesAtlas(atlas=torch.randn(size=(1, F, 4, 4, 3)))
511+
with self.assertRaisesRegex(ValueError, "do not match the dimensions"):
512+
Meshes(verts=verts, faces=faces, textures=tex)
513+
514+
# TexturesAtlas have the wrong number of faces
515+
tex = TexturesAtlas(atlas=torch.randn(size=(N, 15, 4, 4, 3)))
516+
with self.assertRaisesRegex(ValueError, "do not match the dimensions"):
517+
Meshes(verts=verts, faces=faces, textures=tex)
518+
519+
meshes = Meshes(verts=verts, faces=faces)
520+
meshes.textures = tex
521+
522+
# Cannot use the texture attribute set on meshes for sampling
523+
# textures if the dimensions don't match
524+
with self.assertRaisesRegex(ValueError, "do not match the dimensions"):
525+
meshes.sample_textures(None)
526+
476527

477528
class TestTexturesUV(TestCaseMixin, unittest.TestCase):
478529
def setUp(self) -> None:
@@ -824,9 +875,10 @@ def test_mesh_to(self):
824875
def test_getitem(self):
825876
N = 5
826877
V = 20
878+
F = 10
827879
source = {
828880
"maps": torch.rand(size=(N, 1, 1, 3)),
829-
"faces_uvs": torch.randint(size=(N, 10, 3), high=V),
881+
"faces_uvs": torch.randint(size=(N, F, 3), high=V),
830882
"verts_uvs": torch.randn(size=(N, V, 2)),
831883
}
832884
tex = TexturesUV(
@@ -836,7 +888,7 @@ def test_getitem(self):
836888
)
837889

838890
verts = torch.rand(size=(N, V, 3))
839-
faces = torch.randint(size=(N, 10, 3), high=V)
891+
faces = torch.randint(size=(N, F, 3), high=V)
840892
meshes = Meshes(verts=verts, faces=faces, textures=tex)
841893

842894
tryindex(self, 2, tex, meshes, source)
@@ -858,6 +910,46 @@ def test_centers_for_image(self):
858910
expected = torch.FloatTensor([[32, 224], [64, 96], [64, 128]])
859911
self.assertClose(tex.centers_for_image(0), expected)
860912

913+
def test_sample_textures_error(self):
914+
N = 1
915+
V = 20
916+
F = 10
917+
maps = torch.rand(size=(N, 1, 1, 3))
918+
verts_uvs = torch.randn(size=(N, V, 2))
919+
tex = TexturesUV(
920+
maps=maps,
921+
faces_uvs=torch.randint(size=(N, 15, 3), high=V),
922+
verts_uvs=verts_uvs,
923+
)
924+
verts = torch.rand(size=(5, V, 3))
925+
faces = torch.randint(size=(5, 10, 3), high=V)
926+
meshes = Meshes(verts=verts, faces=faces)
927+
928+
# Wrong number of faces
929+
with self.assertRaisesRegex(ValueError, "do not match the dimensions"):
930+
Meshes(verts=verts, faces=faces, textures=tex)
931+
932+
# Wrong batch dim for faces
933+
tex = TexturesUV(
934+
maps=maps,
935+
faces_uvs=torch.randint(size=(1, F, 3), high=V),
936+
verts_uvs=verts_uvs,
937+
)
938+
with self.assertRaisesRegex(ValueError, "do not match the dimensions"):
939+
Meshes(verts=verts, faces=faces, textures=tex)
940+
941+
# Wrong batch dim for verts_uvs is not necessary to check as
942+
# there is already a check inside TexturesUV for a batch dim
943+
# mismatch with faces_uvs
944+
945+
meshes = Meshes(verts=verts, faces=faces)
946+
meshes.textures = tex
947+
948+
# Cannot use the texture attribute set on meshes for sampling
949+
# textures if the dimensions don't match
950+
with self.assertRaisesRegex(ValueError, "do not match the dimensions"):
951+
meshes.sample_textures(None)
952+
861953

862954
class TestRectanglePacking(TestCaseMixin, unittest.TestCase):
863955
def setUp(self) -> None:

0 commit comments

Comments
 (0)