Skip to content

Commit 6c3fe95

Browse files
bottlerfacebook-github-bot
authored andcommitted
PLY TexturesVertex loading
Summary: Include TexturesVertex colors when loading and saving Meshes to PLY files. A couple of other improvements to the internals of ply_io, including using `None` instead of empty tensors for some missing data. Reviewed By: gkioxari Differential Revision: D27765260 fbshipit-source-id: b9857dc777c244b9d7d6643b608596d31435ecda
1 parent 097b0ef commit 6c3fe95

File tree

2 files changed

+135
-53
lines changed

2 files changed

+135
-53
lines changed

pytorch3d/io/ply_io.py

+81-52
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import torch
2121
from iopath.common.file_io import PathManager
2222
from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file
23+
from pytorch3d.renderer import TexturesVertex
2324
from pytorch3d.structures import Meshes, Pointclouds
2425

2526
from .pluggable_formats import (
@@ -66,7 +67,7 @@ class _PlyElementType:
6667
def __init__(self, name: str, count: int):
6768
self.name = name
6869
self.count = count
69-
self.properties = []
70+
self.properties: List[_Property] = []
7071

7172
def add_property(
7273
self, name: str, data_type: str, list_size_type: Optional[str] = None
@@ -142,7 +143,7 @@ def __init__(self, f):
142143
if f.readline() not in [b"ply\n", b"ply\r\n", "ply\n"]:
143144
raise ValueError("Invalid file header.")
144145
seen_format = False
145-
self.elements = []
146+
self.elements: List[_PlyElementType] = []
146147
self.obj_info = []
147148
while True:
148149
line = f.readline()
@@ -891,8 +892,8 @@ def _get_verts(
891892

892893

893894
def _load_ply(
894-
f, *, path_manager: PathManager, return_vertex_colors: bool = False
895-
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
895+
f, *, path_manager: PathManager
896+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
896897
"""
897898
Load the data from a .ply file.
898899
@@ -903,12 +904,11 @@ def _load_ply(
903904
ply format, then a text stream is not supported.
904905
It is easiest to use a binary stream in all cases.
905906
path_manager: PathManager for loading if f is a str.
906-
return_vertex_colors: whether to return vertex colors.
907907
908908
Returns:
909909
verts: FloatTensor of shape (V, 3).
910910
faces: None or LongTensor of vertex indices, shape (F, 3).
911-
vertex_colors: None or FloatTensor of shape (V, 3), only if requested
911+
vertex_colors: None or FloatTensor of shape (V, 3).
912912
"""
913913
header, elements = _load_ply_raw(f, path_manager=path_manager)
914914

@@ -950,16 +950,17 @@ def _load_ply(
950950
if faces is not None:
951951
_check_faces_indices(faces, max_index=verts.shape[0])
952952

953-
if return_vertex_colors:
954-
return verts, faces, vertex_colors
955-
return verts, faces, None
953+
return verts, faces, vertex_colors
956954

957955

958956
def load_ply(
959957
f, *, path_manager: Optional[PathManager] = None
960958
) -> Tuple[torch.Tensor, torch.Tensor]:
961959
"""
962-
Load the data from a .ply file.
960+
Load the verts and faces from a .ply file.
961+
Note that the preferred way to load data from such a file
962+
is to use the IO.load_mesh and IO.load_pointcloud functions,
963+
which can read more of the data.
963964
964965
Example .ply file format:
965966
@@ -1016,8 +1017,8 @@ def _save_ply(
10161017
*,
10171018
verts: torch.Tensor,
10181019
faces: Optional[torch.LongTensor],
1019-
verts_normals: torch.Tensor,
1020-
verts_colors: torch.Tensor,
1020+
verts_normals: Optional[torch.Tensor],
1021+
verts_colors: Optional[torch.Tensor],
10211022
ascii: bool,
10221023
decimal_places: Optional[int] = None,
10231024
) -> None:
@@ -1029,16 +1030,16 @@ def _save_ply(
10291030
verts: FloatTensor of shape (V, 3) giving vertex coordinates.
10301031
faces: LongTensor of shape (F, 3) giving faces.
10311032
verts_normals: FloatTensor of shape (V, 3) giving vertex normals.
1033+
verts_colors: FloatTensor of shape (V, 3) giving vertex colors.
10321034
ascii: (bool) whether to use the ascii ply format.
10331035
decimal_places: Number of decimal places for saving if ascii=True.
10341036
"""
10351037
assert not len(verts) or (verts.dim() == 2 and verts.size(1) == 3)
1036-
if faces is not None:
1037-
assert not len(faces) or (faces.dim() == 2 and faces.size(1) == 3)
1038-
assert not len(verts_normals) or (
1038+
assert faces is None or not len(faces) or (faces.dim() == 2 and faces.size(1) == 3)
1039+
assert verts_normals is None or (
10391040
verts_normals.dim() == 2 and verts_normals.size(1) == 3
10401041
)
1041-
assert not len(verts_colors) or (
1042+
assert verts_colors is None or (
10421043
verts_colors.dim() == 2 and verts_colors.size(1) == 3
10431044
)
10441045

@@ -1052,11 +1053,11 @@ def _save_ply(
10521053
f.write(b"property float x\n")
10531054
f.write(b"property float y\n")
10541055
f.write(b"property float z\n")
1055-
if verts_normals.numel() > 0:
1056+
if verts_normals is not None:
10561057
f.write(b"property float nx\n")
10571058
f.write(b"property float ny\n")
10581059
f.write(b"property float nz\n")
1059-
if verts_colors.numel() > 0:
1060+
if verts_colors is not None:
10601061
f.write(b"property float red\n")
10611062
f.write(b"property float green\n")
10621063
f.write(b"property float blue\n")
@@ -1069,7 +1070,13 @@ def _save_ply(
10691070
warnings.warn("Empty 'verts' provided")
10701071
return
10711072

1072-
vert_data = torch.cat((verts, verts_normals, verts_colors), dim=1).detach().numpy()
1073+
verts_tensors = [verts]
1074+
if verts_normals is not None:
1075+
verts_tensors.append(verts_normals)
1076+
if verts_colors is not None:
1077+
verts_tensors.append(verts_colors)
1078+
1079+
vert_data = torch.cat(verts_tensors, dim=1).detach().cpu().numpy()
10731080
if ascii:
10741081
if decimal_places is None:
10751082
float_str = "%f"
@@ -1085,7 +1092,7 @@ def _save_ply(
10851092
vert_data.tofile(f)
10861093

10871094
if faces is not None:
1088-
faces_array = faces.detach().numpy()
1095+
faces_array = faces.detach().cpu().numpy()
10891096

10901097
_check_faces_indices(faces, max_index=verts.shape[0])
10911098

@@ -1125,12 +1132,6 @@ def save_ply(
11251132
11261133
"""
11271134

1128-
verts_normals = (
1129-
torch.tensor([], dtype=torch.float32, device=verts.device)
1130-
if verts_normals is None
1131-
else verts_normals
1132-
)
1133-
11341135
if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3):
11351136
message = "Argument 'verts' should either be empty or of shape (num_verts, 3)."
11361137
raise ValueError(message)
@@ -1143,16 +1144,18 @@ def save_ply(
11431144
message = "Argument 'faces' should either be empty or of shape (num_faces, 3)."
11441145
raise ValueError(message)
11451146

1146-
if len(verts_normals) and not (
1147-
verts_normals.dim() == 2
1148-
and verts_normals.size(1) == 3
1149-
and verts_normals.size(0) == verts.size(0)
1147+
if (
1148+
verts_normals is not None
1149+
and len(verts_normals)
1150+
and not (
1151+
verts_normals.dim() == 2
1152+
and verts_normals.size(1) == 3
1153+
and verts_normals.size(0) == verts.size(0)
1154+
)
11501155
):
11511156
message = "Argument 'verts_normals' should either be empty or of shape (num_verts, 3)."
11521157
raise ValueError(message)
11531158

1154-
verts_colors = torch.FloatTensor([])
1155-
11561159
if path_manager is None:
11571160
path_manager = PathManager()
11581161
with _open_file(f, path_manager, "wb") as f:
@@ -1161,7 +1164,7 @@ def save_ply(
11611164
verts=verts,
11621165
faces=faces,
11631166
verts_normals=verts_normals,
1164-
verts_colors=verts_colors,
1167+
verts_colors=None,
11651168
ascii=ascii,
11661169
decimal_places=decimal_places,
11671170
)
@@ -1182,8 +1185,19 @@ def read(
11821185
if not endswith(path, self.known_suffixes):
11831186
return None
11841187

1185-
verts, faces = load_ply(f=path, path_manager=path_manager)
1186-
mesh = Meshes(verts=[verts.to(device)], faces=[faces.to(device)])
1188+
verts, faces, verts_colors = _load_ply(f=path, path_manager=path_manager)
1189+
if faces is None:
1190+
faces = torch.zeros(0, 3, dtype=torch.int64)
1191+
1192+
textures = None
1193+
if include_textures and verts_colors is not None:
1194+
textures = TexturesVertex([verts_colors.to(device)])
1195+
1196+
mesh = Meshes(
1197+
verts=[verts.to(device)],
1198+
faces=[faces.to(device)],
1199+
textures=textures,
1200+
)
11871201
return mesh
11881202

11891203
def save(
@@ -1201,14 +1215,30 @@ def save(
12011215
# TODO: normals are not saved. We only want to save them if they already exist.
12021216
verts = data.verts_list()[0]
12031217
faces = data.faces_list()[0]
1204-
save_ply(
1205-
f=path,
1206-
verts=verts,
1207-
faces=faces,
1208-
ascii=binary is False,
1209-
decimal_places=decimal_places,
1210-
path_manager=path_manager,
1211-
)
1218+
1219+
if isinstance(data.textures, TexturesVertex):
1220+
mesh_verts_colors = data.textures.verts_features_list()[0]
1221+
n_colors = mesh_verts_colors.shape[1]
1222+
if n_colors == 3:
1223+
verts_colors = mesh_verts_colors
1224+
else:
1225+
warnings.warn(
1226+
f"Texture will not be saved as it has {n_colors} colors, not 3."
1227+
)
1228+
verts_colors = None
1229+
else:
1230+
verts_colors = None
1231+
1232+
with _open_file(path, path_manager, "wb") as f:
1233+
_save_ply(
1234+
f=f,
1235+
verts=verts,
1236+
faces=faces,
1237+
verts_colors=verts_colors,
1238+
verts_normals=None,
1239+
ascii=binary is False,
1240+
decimal_places=decimal_places,
1241+
)
12121242
return True
12131243

12141244

@@ -1226,14 +1256,12 @@ def read(
12261256
if not endswith(path, self.known_suffixes):
12271257
return None
12281258

1229-
verts, faces, features = _load_ply(
1230-
f=path, path_manager=path_manager, return_vertex_colors=True
1231-
)
1259+
verts, faces, features = _load_ply(f=path, path_manager=path_manager)
12321260
verts = verts.to(device)
1233-
if features is None:
1234-
pointcloud = Pointclouds(points=[verts])
1235-
else:
1236-
pointcloud = Pointclouds(points=[verts], features=[features.to(device)])
1261+
if features is not None:
1262+
features = [features.to(device)]
1263+
1264+
pointcloud = Pointclouds(points=[verts], features=features)
12371265
return pointcloud
12381266

12391267
def save(
@@ -1249,13 +1277,14 @@ def save(
12491277
return False
12501278

12511279
points = data.points_list()[0]
1252-
features = data.features_list()[0]
1280+
features = data.features_packed()
1281+
12531282
with _open_file(path, path_manager, "wb") as f:
12541283
_save_ply(
12551284
f=f,
12561285
verts=points,
12571286
verts_colors=features,
1258-
verts_normals=torch.FloatTensor([]),
1287+
verts_normals=None,
12591288
faces=None,
12601289
ascii=binary is False,
12611290
decimal_places=decimal_places,

tests/test_io_ply.py

+54-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
22

3+
import itertools
34
import struct
45
import unittest
56
from io import BytesIO, StringIO
@@ -12,7 +13,8 @@
1213
from iopath.common.file_io import PathManager
1314
from pytorch3d.io import IO
1415
from pytorch3d.io.ply_io import load_ply, save_ply
15-
from pytorch3d.structures import Pointclouds
16+
from pytorch3d.renderer.mesh import TexturesVertex
17+
from pytorch3d.structures import Meshes, Pointclouds
1618
from pytorch3d.utils import torus
1719

1820

@@ -189,6 +191,57 @@ def test_pluggable_load_cube(self):
189191
):
190192
io.load_mesh(f3.name)
191193

194+
def test_save_too_many_colors(self):
195+
verts = torch.tensor(
196+
[[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=torch.float32
197+
)
198+
faces = torch.tensor([[0, 1, 2], [0, 2, 3]])
199+
vert_colors = torch.rand((4, 7))
200+
texture_with_seven_colors = TexturesVertex(verts_features=[vert_colors])
201+
202+
mesh = Meshes(
203+
verts=[verts],
204+
faces=[faces],
205+
textures=texture_with_seven_colors,
206+
)
207+
208+
io = IO()
209+
msg = "Texture will not be saved as it has 7 colors, not 3."
210+
with NamedTemporaryFile(mode="w", suffix=".ply") as f:
211+
with self.assertWarnsRegex(UserWarning, msg):
212+
io.save_mesh(mesh.cuda(), f.name)
213+
214+
def test_save_load_meshes(self):
215+
verts = torch.tensor(
216+
[[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=torch.float32
217+
)
218+
faces = torch.tensor([[0, 1, 2], [0, 2, 3]])
219+
vert_colors = torch.rand_like(verts)
220+
texture = TexturesVertex(verts_features=[vert_colors])
221+
222+
for do_textures in itertools.product([True, False]):
223+
mesh = Meshes(
224+
verts=[verts],
225+
faces=[faces],
226+
textures=texture if do_textures else None,
227+
)
228+
device = torch.device("cuda:0")
229+
230+
io = IO()
231+
with NamedTemporaryFile(mode="w", suffix=".ply") as f:
232+
io.save_mesh(mesh.cuda(), f.name)
233+
f.flush()
234+
mesh2 = io.load_mesh(f.name, device=device)
235+
self.assertEqual(mesh2.device, device)
236+
mesh2 = mesh2.cpu()
237+
self.assertClose(mesh2.verts_padded(), mesh.verts_padded())
238+
self.assertClose(mesh2.faces_padded(), mesh.faces_padded())
239+
if do_textures:
240+
self.assertIsInstance(mesh2.textures, TexturesVertex)
241+
self.assertClose(mesh2.textures.verts_features_list()[0], vert_colors)
242+
else:
243+
self.assertIsNone(mesh2.textures)
244+
192245
def test_save_ply_invalid_shapes(self):
193246
# Invalid vertices shape
194247
with self.assertRaises(ValueError) as error:

0 commit comments

Comments
 (0)