Skip to content

Commit 5d3cc35

Browse files
nikhilaravifacebook-github-bot
authored andcommitted
Rendering texturing fixes
Summary: Fix errors raised by issue on GitHub - extending mesh textures + rendering with Gourad and Phong shaders. #97 Reviewed By: gkioxari Differential Revision: D20319610 fbshipit-source-id: d1c692ff0b9397a77a9b829c5c731790de70c09f
1 parent f580ce1 commit 5d3cc35

10 files changed

+350
-154
lines changed

pytorch3d/renderer/mesh/texturing.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,9 @@ def interpolate_vertex_colors(fragments, meshes) -> torch.Tensor:
107107
There will be one C dimensional value for each element in
108108
fragments.pix_to_face.
109109
"""
110-
vertex_textures = meshes.textures.verts_rgb_padded().view(-1, 3) # (V, C)
110+
vertex_textures = meshes.textures.verts_rgb_padded().reshape(
111+
-1, 3
112+
) # (V, C)
111113
vertex_textures = vertex_textures[meshes.verts_padded_to_packed_idx(), :]
112114
faces_packed = meshes.faces_packed()
113115
faces_textures = vertex_textures[faces_packed] # (F, 3, C)

pytorch3d/renderer/utils.py

+11-6
Original file line numberDiff line numberDiff line change
@@ -223,27 +223,32 @@ def gather_props(self, batch_idx):
223223
self with all properties reshaped. e.g. a property with shape (N, 3)
224224
is transformed to shape (B, 3).
225225
"""
226+
# Iterate through the attributes of the class which are tensors.
226227
for k in dir(self):
227228
v = getattr(self, k)
228229
if torch.is_tensor(v):
229230
if v.shape[0] > 1:
230231
# There are different values for each batch element
231-
# so gather these using the batch_idx
232-
idx_dims = batch_idx.shape
232+
# so gather these using the batch_idx.
233+
# First clone the input batch_idx tensor before
234+
# modifying it.
235+
_batch_idx = batch_idx.clone()
236+
idx_dims = _batch_idx.shape
233237
tensor_dims = v.shape
234238
if len(idx_dims) > len(tensor_dims):
235239
msg = "batch_idx cannot have more dimensions than %s. "
236240
msg += "got shape %r and %s has shape %r"
237241
raise ValueError(msg % (k, idx_dims, k, tensor_dims))
238242
if idx_dims != tensor_dims:
239-
# To use torch.gather the index tensor (batch_idx) has
243+
# To use torch.gather the index tensor (_batch_idx) has
240244
# to have the same shape as the input tensor.
241245
new_dims = len(tensor_dims) - len(idx_dims)
242246
new_shape = idx_dims + (1,) * new_dims
243247
expand_dims = (-1,) + tensor_dims[1:]
244-
batch_idx = batch_idx.view(*new_shape)
245-
batch_idx = batch_idx.expand(*expand_dims)
246-
v = v.gather(0, batch_idx)
248+
_batch_idx = _batch_idx.view(*new_shape)
249+
_batch_idx = _batch_idx.expand(*expand_dims)
250+
251+
v = v.gather(0, _batch_idx)
247252
setattr(self, k, v)
248253
return self
249254

pytorch3d/structures/meshes.py

+36-13
Original file line numberDiff line numberDiff line change
@@ -324,14 +324,14 @@ def __init__(self, verts=None, faces=None, textures=None):
324324
)
325325
if self._N > 0:
326326
self.device = self._verts_list[0].device
327-
num_verts_per_mesh = torch.tensor(
327+
self._num_verts_per_mesh = torch.tensor(
328328
[len(v) for v in self._verts_list], device=self.device
329329
)
330-
self._V = num_verts_per_mesh.max()
331-
num_faces_per_mesh = torch.tensor(
330+
self._V = self._num_verts_per_mesh.max()
331+
self._num_faces_per_mesh = torch.tensor(
332332
[len(f) for f in self._faces_list], device=self.device
333333
)
334-
self._F = num_faces_per_mesh.max()
334+
self._F = self._num_faces_per_mesh.max()
335335
self.valid = torch.tensor(
336336
[
337337
len(v) > 0 and len(f) > 0
@@ -341,8 +341,8 @@ def __init__(self, verts=None, faces=None, textures=None):
341341
device=self.device,
342342
)
343343

344-
if (len(num_verts_per_mesh.unique()) == 1) and (
345-
len(num_faces_per_mesh.unique()) == 1
344+
if (len(self._num_verts_per_mesh.unique()) == 1) and (
345+
len(self._num_faces_per_mesh.unique()) == 1
346346
):
347347
self.equisized = True
348348

@@ -355,6 +355,7 @@ def __init__(self, verts=None, faces=None, textures=None):
355355
self._faces_padded = faces.to(torch.int64)
356356
self._N = self._verts_padded.shape[0]
357357
self._V = self._verts_padded.shape[1]
358+
358359
self.device = self._verts_padded.device
359360
self.valid = torch.zeros(
360361
(self._N,), dtype=torch.bool, device=self.device
@@ -363,25 +364,49 @@ def __init__(self, verts=None, faces=None, textures=None):
363364
# Check that padded faces - which have value -1 - are at the
364365
# end of the tensors
365366
faces_not_padded = self._faces_padded.gt(-1).all(2)
366-
num_faces = faces_not_padded.sum(1)
367+
self._num_faces_per_mesh = faces_not_padded.sum(1)
367368
if (faces_not_padded[:, :-1] < faces_not_padded[:, 1:]).any():
368369
raise ValueError("Padding of faces must be at the end")
369370

370371
# NOTE that we don't check for the ordering of padded verts
371372
# as long as the faces index correspond to the right vertices.
372373

373-
self.valid = num_faces > 0
374-
self._F = num_faces.max()
375-
if len(num_faces.unique()) == 1:
374+
self.valid = self._num_faces_per_mesh > 0
375+
self._F = self._num_faces_per_mesh.max()
376+
if len(self._num_faces_per_mesh.unique()) == 1:
376377
self.equisized = True
377378

379+
self._num_verts_per_mesh = torch.full(
380+
size=(self._N,),
381+
fill_value=self._V,
382+
dtype=torch.int64,
383+
device=self.device,
384+
)
385+
378386
else:
379387
raise ValueError(
380388
"Verts and Faces must be either a list or a tensor with \
381389
shape (batch_size, N, 3) where N is either the maximum \
382390
number of verts or faces respectively."
383391
)
384392

393+
if self.isempty():
394+
self._num_verts_per_mesh = torch.zeros(
395+
(0,), dtype=torch.int64, device=self.device
396+
)
397+
self._num_faces_per_mesh = torch.zeros(
398+
(0,), dtype=torch.int64, device=self.device
399+
)
400+
401+
# Set the num verts/faces on the textures if present.
402+
if self.textures is not None:
403+
self.textures._num_faces_per_mesh = (
404+
self._num_faces_per_mesh.tolist()
405+
)
406+
self.textures._num_verts_per_mesh = (
407+
self._num_verts_per_mesh.tolist()
408+
)
409+
385410
def __len__(self):
386411
return self._N
387412

@@ -893,11 +918,9 @@ def _compute_packed(self, refresh: bool = False):
893918
self._verts_packed,
894919
self._verts_packed_to_mesh_idx,
895920
self._mesh_to_verts_packed_first_idx,
896-
self._num_verts_per_mesh,
897921
self._faces_packed,
898922
self._faces_packed_to_mesh_idx,
899923
self._mesh_to_faces_packed_first_idx,
900-
self._num_faces_per_mesh,
901924
]
902925
)
903926
):
@@ -920,7 +943,6 @@ def _compute_packed(self, refresh: bool = False):
920943
self._num_verts_per_mesh = torch.zeros(
921944
(0,), dtype=torch.int64, device=self.device
922945
)
923-
924946
self._faces_packed = -torch.ones(
925947
(0, 3), dtype=torch.int64, device=self.device
926948
)
@@ -1354,6 +1376,7 @@ def extend(self, N: int):
13541376
tex = None
13551377
if self.textures is not None:
13561378
tex = self.textures.extend(N)
1379+
13571380
return Meshes(verts=new_verts_list, faces=new_faces_list, textures=tex)
13581381

13591382

pytorch3d/structures/textures.py

+64-33
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import torch
55
import torchvision.transforms as T
66

7-
from .utils import list_to_packed, padded_to_list
7+
from .utils import padded_to_list, padded_to_packed
88

99

1010
"""
@@ -92,14 +92,19 @@ def __init__(
9292
faces_uvs: (N, F, 3) tensor giving the index into verts_uvs for each
9393
vertex in the face. Padding value is assumed to be -1.
9494
verts_uvs: (N, V, 2) tensor giving the uv coordinate per vertex.
95-
verts_rgb: (N, V, 3) tensor giving the rgb color per vertex.
95+
verts_rgb: (N, V, 3) tensor giving the rgb color per vertex. Padding
96+
value is assumed to be -1.
97+
98+
Note: only the padded representation of the textures is stored
99+
and the packed/list representations are computed on the fly and
100+
not cached.
96101
"""
97102
if faces_uvs is not None and faces_uvs.ndim != 3:
98103
msg = "Expected faces_uvs to be of shape (N, F, 3); got %r"
99104
raise ValueError(msg % repr(faces_uvs.shape))
100105
if verts_uvs is not None and verts_uvs.ndim != 3:
101106
msg = "Expected verts_uvs to be of shape (N, V, 2); got %r"
102-
raise ValueError(msg % repr(faces_uvs.shape))
107+
raise ValueError(msg % repr(verts_uvs.shape))
103108
if verts_rgb is not None and verts_rgb.ndim != 3:
104109
msg = "Expected verts_rgb to be of shape (N, V, 3); got %r"
105110
raise ValueError(msg % repr(verts_rgb.shape))
@@ -109,20 +114,20 @@ def __init__(
109114
raise ValueError(msg % repr(maps.shape))
110115
elif isinstance(maps, list):
111116
maps = _pad_texture_maps(maps)
117+
if faces_uvs is None or verts_uvs is None:
118+
msg = "To use maps, faces_uvs and verts_uvs are required"
119+
raise ValueError(msg)
120+
112121
self._faces_uvs_padded = faces_uvs
113122
self._verts_uvs_padded = verts_uvs
114123
self._verts_rgb_padded = verts_rgb
115124
self._maps_padded = maps
116-
self._num_faces_per_mesh = None
117-
self._set_num_faces_per_mesh()
118125

119-
def _set_num_faces_per_mesh(self) -> None:
120-
"""
121-
Determines and sets the number of textured faces for each mesh.
122-
"""
123-
if self._faces_uvs_padded is not None:
124-
faces_uvs = self._faces_uvs_padded
125-
self._num_faces_per_mesh = faces_uvs.gt(-1).all(-1).sum(-1).tolist()
126+
# The number of faces/verts for each mesh is
127+
# set inside the Meshes object when textures is
128+
# passed into the Meshes constructor.
129+
self._num_faces_per_mesh = None
130+
self._num_verts_per_mesh = None
126131

127132
def clone(self):
128133
other = Textures()
@@ -148,41 +153,67 @@ def __getitem__(self, index):
148153
setattr(other, key, value[index][None])
149154
else:
150155
setattr(other, key, value[index])
151-
other._set_num_faces_per_mesh()
152156
return other
153157

154158
def faces_uvs_padded(self) -> torch.Tensor:
155159
return self._faces_uvs_padded
156160

157-
def faces_uvs_list(self) -> List[torch.Tensor]:
158-
if self._faces_uvs_padded is not None:
159-
return padded_to_list(
160-
self._faces_uvs_padded, split_size=self._num_faces_per_mesh
161-
)
162-
163-
def faces_uvs_packed(self) -> torch.Tensor:
164-
return list_to_packed(self.faces_uvs_list())[0]
165-
166-
def verts_uvs_padded(self) -> torch.Tensor:
161+
def faces_uvs_list(self) -> Union[List[torch.Tensor], None]:
162+
if self._faces_uvs_padded is None:
163+
return None
164+
return padded_to_list(
165+
self._faces_uvs_padded, split_size=self._num_faces_per_mesh
166+
)
167+
168+
def faces_uvs_packed(self) -> Union[torch.Tensor, None]:
169+
if self._faces_uvs_padded is None:
170+
return None
171+
return padded_to_packed(
172+
self._faces_uvs_padded, split_size=self._num_faces_per_mesh
173+
)
174+
175+
def verts_uvs_padded(self) -> Union[torch.Tensor, None]:
167176
return self._verts_uvs_padded
168177

169-
def verts_uvs_list(self) -> List[torch.Tensor]:
178+
def verts_uvs_list(self) -> Union[List[torch.Tensor], None]:
179+
if self._verts_uvs_padded is None:
180+
return None
181+
# Vertices shared between multiple faces
182+
# may have a different uv coordinate for
183+
# each face so the num_verts_uvs_per_mesh
184+
# may be different from num_verts_per_mesh.
185+
# Therefore don't use any split_size.
170186
return padded_to_list(self._verts_uvs_padded)
171187

172-
def verts_uvs_packed(self) -> torch.Tensor:
173-
return list_to_packed(self.verts_uvs_list())[0]
174-
175-
def verts_rgb_padded(self) -> torch.Tensor:
188+
def verts_uvs_packed(self) -> Union[torch.Tensor, None]:
189+
if self._verts_uvs_padded is None:
190+
return None
191+
# Vertices shared between multiple faces
192+
# may have a different uv coordinate for
193+
# each face so the num_verts_uvs_per_mesh
194+
# may be different from num_verts_per_mesh.
195+
# Therefore don't use any split_size.
196+
return padded_to_packed(self._verts_uvs_padded)
197+
198+
def verts_rgb_padded(self) -> Union[torch.Tensor, None]:
176199
return self._verts_rgb_padded
177200

178-
def verts_rgb_list(self) -> List[torch.Tensor]:
179-
return padded_to_list(self._verts_rgb_padded)
201+
def verts_rgb_list(self) -> Union[List[torch.Tensor], None]:
202+
if self._verts_rgb_padded is None:
203+
return None
204+
return padded_to_list(
205+
self._verts_rgb_padded, split_size=self._num_verts_per_mesh
206+
)
180207

181-
def verts_rgb_packed(self) -> torch.Tensor:
182-
return list_to_packed(self.verts_rgb_list())[0]
208+
def verts_rgb_packed(self) -> Union[torch.Tensor, None]:
209+
if self._verts_rgb_padded is None:
210+
return None
211+
return padded_to_packed(
212+
self._verts_rgb_padded, split_size=self._num_verts_per_mesh
213+
)
183214

184215
# Currently only the padded maps are used.
185-
def maps_padded(self) -> torch.Tensor:
216+
def maps_padded(self) -> Union[torch.Tensor, None]:
186217
return self._maps_padded
187218

188219
def extend(self, N: int) -> "Textures":

tests/test_meshes.py

+14-8
Original file line numberDiff line numberDiff line change
@@ -135,16 +135,22 @@ def init_simple_mesh(device: str = "cpu"):
135135
def test_simple(self):
136136
mesh = TestMeshes.init_simple_mesh("cuda:0")
137137

138+
# Check that faces/verts per mesh are set in init:
139+
self.assertClose(
140+
mesh._num_faces_per_mesh.cpu(), torch.tensor([1, 2, 7])
141+
)
142+
self.assertClose(
143+
mesh._num_verts_per_mesh.cpu(), torch.tensor([3, 4, 5])
144+
)
145+
146+
# Check computed tensors
138147
self.assertClose(
139148
mesh.verts_packed_to_mesh_idx().cpu(),
140149
torch.tensor([0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]),
141150
)
142151
self.assertClose(
143152
mesh.mesh_to_verts_packed_first_idx().cpu(), torch.tensor([0, 3, 7])
144153
)
145-
self.assertClose(
146-
mesh.num_verts_per_mesh().cpu(), torch.tensor([3, 4, 5])
147-
)
148154
self.assertClose(
149155
mesh.verts_padded_to_packed_idx().cpu(),
150156
torch.tensor([0, 1, 2, 5, 6, 7, 8, 10, 11, 12, 13, 14]),
@@ -156,9 +162,6 @@ def test_simple(self):
156162
self.assertClose(
157163
mesh.mesh_to_faces_packed_first_idx().cpu(), torch.tensor([0, 1, 3])
158164
)
159-
self.assertClose(
160-
mesh.num_faces_per_mesh().cpu(), torch.tensor([1, 2, 7])
161-
)
162165
self.assertClose(
163166
mesh.num_edges_per_mesh().cpu(),
164167
torch.tensor([3, 5, 10], dtype=torch.int32),
@@ -249,6 +252,8 @@ def test_allempty(self):
249252
self.assertEqual(mesh.faces_padded().shape[0], 0)
250253
self.assertEqual(mesh.verts_packed().shape[0], 0)
251254
self.assertEqual(mesh.faces_packed().shape[0], 0)
255+
self.assertEqual(mesh.num_faces_per_mesh().shape[0], 0)
256+
self.assertEqual(mesh.num_verts_per_mesh().shape[0], 0)
252257

253258
def test_empty(self):
254259
N, V, F = 10, 100, 300
@@ -323,9 +328,11 @@ def test_padding(self):
323328

324329
mesh = Meshes(verts=torch.stack(verts), faces=torch.stack(faces))
325330

331+
# Check verts/faces per mesh are set correctly in init.
326332
self.assertListEqual(
327-
mesh.num_faces_per_mesh().tolist(), num_faces.tolist()
333+
mesh._num_faces_per_mesh.tolist(), num_faces.tolist()
328334
)
335+
self.assertListEqual(mesh._num_verts_per_mesh.tolist(), [V] * N)
329336

330337
for n, (vv, ff) in enumerate(zip(mesh.verts_list(), mesh.faces_list())):
331338
self.assertClose(ff, faces[n][: num_faces[n]])
@@ -364,7 +371,6 @@ def test_clone(self):
364371
mesh._num_verts_per_mesh = torch.randint_like(
365372
mesh.num_verts_per_mesh(), high=10
366373
)
367-
368374
# Check cloned and original Meshes objects do not share tensors.
369375
self.assertFalse(
370376
torch.allclose(new_mesh._verts_list[0], mesh._verts_list[0])

0 commit comments

Comments
 (0)