20
20
import torch
21
21
from iopath .common .file_io import PathManager
22
22
from pytorch3d .io .utils import _check_faces_indices , _make_tensor , _open_file
23
+ from pytorch3d .renderer import TexturesVertex
23
24
from pytorch3d .structures import Meshes , Pointclouds
24
25
25
26
from .pluggable_formats import (
@@ -66,7 +67,7 @@ class _PlyElementType:
66
67
def __init__ (self , name : str , count : int ):
67
68
self .name = name
68
69
self .count = count
69
- self .properties = []
70
+ self .properties : List [ _Property ] = []
70
71
71
72
def add_property (
72
73
self , name : str , data_type : str , list_size_type : Optional [str ] = None
@@ -142,7 +143,7 @@ def __init__(self, f):
142
143
if f .readline () not in [b"ply\n " , b"ply\r \n " , "ply\n " ]:
143
144
raise ValueError ("Invalid file header." )
144
145
seen_format = False
145
- self .elements = []
146
+ self .elements : List [ _PlyElementType ] = []
146
147
self .obj_info = []
147
148
while True :
148
149
line = f .readline ()
@@ -891,8 +892,8 @@ def _get_verts(
891
892
892
893
893
894
def _load_ply (
894
- f , * , path_manager : PathManager , return_vertex_colors : bool = False
895
- ) -> Tuple [torch .Tensor , torch .Tensor , Optional [torch .Tensor ]]:
895
+ f , * , path_manager : PathManager
896
+ ) -> Tuple [torch .Tensor , Optional [ torch .Tensor ] , Optional [torch .Tensor ]]:
896
897
"""
897
898
Load the data from a .ply file.
898
899
@@ -903,12 +904,11 @@ def _load_ply(
903
904
ply format, then a text stream is not supported.
904
905
It is easiest to use a binary stream in all cases.
905
906
path_manager: PathManager for loading if f is a str.
906
- return_vertex_colors: whether to return vertex colors.
907
907
908
908
Returns:
909
909
verts: FloatTensor of shape (V, 3).
910
910
faces: None or LongTensor of vertex indices, shape (F, 3).
911
- vertex_colors: None or FloatTensor of shape (V, 3), only if requested
911
+ vertex_colors: None or FloatTensor of shape (V, 3).
912
912
"""
913
913
header , elements = _load_ply_raw (f , path_manager = path_manager )
914
914
@@ -950,16 +950,17 @@ def _load_ply(
950
950
if faces is not None :
951
951
_check_faces_indices (faces , max_index = verts .shape [0 ])
952
952
953
- if return_vertex_colors :
954
- return verts , faces , vertex_colors
955
- return verts , faces , None
953
+ return verts , faces , vertex_colors
956
954
957
955
958
956
def load_ply (
959
957
f , * , path_manager : Optional [PathManager ] = None
960
958
) -> Tuple [torch .Tensor , torch .Tensor ]:
961
959
"""
962
- Load the data from a .ply file.
960
+ Load the verts and faces from a .ply file.
961
+ Note that the preferred way to load data from such a file
962
+ is to use the IO.load_mesh and IO.load_pointcloud functions,
963
+ which can read more of the data.
963
964
964
965
Example .ply file format:
965
966
@@ -1016,8 +1017,8 @@ def _save_ply(
1016
1017
* ,
1017
1018
verts : torch .Tensor ,
1018
1019
faces : Optional [torch .LongTensor ],
1019
- verts_normals : torch .Tensor ,
1020
- verts_colors : torch .Tensor ,
1020
+ verts_normals : Optional [ torch .Tensor ] ,
1021
+ verts_colors : Optional [ torch .Tensor ] ,
1021
1022
ascii : bool ,
1022
1023
decimal_places : Optional [int ] = None ,
1023
1024
) -> None :
@@ -1029,16 +1030,16 @@ def _save_ply(
1029
1030
verts: FloatTensor of shape (V, 3) giving vertex coordinates.
1030
1031
faces: LongTensor of shape (F, 3) giving faces.
1031
1032
verts_normals: FloatTensor of shape (V, 3) giving vertex normals.
1033
+ verts_colors: FloatTensor of shape (V, 3) giving vertex colors.
1032
1034
ascii: (bool) whether to use the ascii ply format.
1033
1035
decimal_places: Number of decimal places for saving if ascii=True.
1034
1036
"""
1035
1037
assert not len (verts ) or (verts .dim () == 2 and verts .size (1 ) == 3 )
1036
- if faces is not None :
1037
- assert not len (faces ) or (faces .dim () == 2 and faces .size (1 ) == 3 )
1038
- assert not len (verts_normals ) or (
1038
+ assert faces is None or not len (faces ) or (faces .dim () == 2 and faces .size (1 ) == 3 )
1039
+ assert verts_normals is None or (
1039
1040
verts_normals .dim () == 2 and verts_normals .size (1 ) == 3
1040
1041
)
1041
- assert not len ( verts_colors ) or (
1042
+ assert verts_colors is None or (
1042
1043
verts_colors .dim () == 2 and verts_colors .size (1 ) == 3
1043
1044
)
1044
1045
@@ -1052,11 +1053,11 @@ def _save_ply(
1052
1053
f .write (b"property float x\n " )
1053
1054
f .write (b"property float y\n " )
1054
1055
f .write (b"property float z\n " )
1055
- if verts_normals . numel () > 0 :
1056
+ if verts_normals is not None :
1056
1057
f .write (b"property float nx\n " )
1057
1058
f .write (b"property float ny\n " )
1058
1059
f .write (b"property float nz\n " )
1059
- if verts_colors . numel () > 0 :
1060
+ if verts_colors is not None :
1060
1061
f .write (b"property float red\n " )
1061
1062
f .write (b"property float green\n " )
1062
1063
f .write (b"property float blue\n " )
@@ -1069,7 +1070,13 @@ def _save_ply(
1069
1070
warnings .warn ("Empty 'verts' provided" )
1070
1071
return
1071
1072
1072
- vert_data = torch .cat ((verts , verts_normals , verts_colors ), dim = 1 ).detach ().numpy ()
1073
+ verts_tensors = [verts ]
1074
+ if verts_normals is not None :
1075
+ verts_tensors .append (verts_normals )
1076
+ if verts_colors is not None :
1077
+ verts_tensors .append (verts_colors )
1078
+
1079
+ vert_data = torch .cat (verts_tensors , dim = 1 ).detach ().cpu ().numpy ()
1073
1080
if ascii :
1074
1081
if decimal_places is None :
1075
1082
float_str = "%f"
@@ -1085,7 +1092,7 @@ def _save_ply(
1085
1092
vert_data .tofile (f )
1086
1093
1087
1094
if faces is not None :
1088
- faces_array = faces .detach ().numpy ()
1095
+ faces_array = faces .detach ().cpu (). numpy ()
1089
1096
1090
1097
_check_faces_indices (faces , max_index = verts .shape [0 ])
1091
1098
@@ -1125,12 +1132,6 @@ def save_ply(
1125
1132
1126
1133
"""
1127
1134
1128
- verts_normals = (
1129
- torch .tensor ([], dtype = torch .float32 , device = verts .device )
1130
- if verts_normals is None
1131
- else verts_normals
1132
- )
1133
-
1134
1135
if len (verts ) and not (verts .dim () == 2 and verts .size (1 ) == 3 ):
1135
1136
message = "Argument 'verts' should either be empty or of shape (num_verts, 3)."
1136
1137
raise ValueError (message )
@@ -1143,16 +1144,18 @@ def save_ply(
1143
1144
message = "Argument 'faces' should either be empty or of shape (num_faces, 3)."
1144
1145
raise ValueError (message )
1145
1146
1146
- if len (verts_normals ) and not (
1147
- verts_normals .dim () == 2
1148
- and verts_normals .size (1 ) == 3
1149
- and verts_normals .size (0 ) == verts .size (0 )
1147
+ if (
1148
+ verts_normals is not None
1149
+ and len (verts_normals )
1150
+ and not (
1151
+ verts_normals .dim () == 2
1152
+ and verts_normals .size (1 ) == 3
1153
+ and verts_normals .size (0 ) == verts .size (0 )
1154
+ )
1150
1155
):
1151
1156
message = "Argument 'verts_normals' should either be empty or of shape (num_verts, 3)."
1152
1157
raise ValueError (message )
1153
1158
1154
- verts_colors = torch .FloatTensor ([])
1155
-
1156
1159
if path_manager is None :
1157
1160
path_manager = PathManager ()
1158
1161
with _open_file (f , path_manager , "wb" ) as f :
@@ -1161,7 +1164,7 @@ def save_ply(
1161
1164
verts = verts ,
1162
1165
faces = faces ,
1163
1166
verts_normals = verts_normals ,
1164
- verts_colors = verts_colors ,
1167
+ verts_colors = None ,
1165
1168
ascii = ascii ,
1166
1169
decimal_places = decimal_places ,
1167
1170
)
@@ -1182,8 +1185,19 @@ def read(
1182
1185
if not endswith (path , self .known_suffixes ):
1183
1186
return None
1184
1187
1185
- verts , faces = load_ply (f = path , path_manager = path_manager )
1186
- mesh = Meshes (verts = [verts .to (device )], faces = [faces .to (device )])
1188
+ verts , faces , verts_colors = _load_ply (f = path , path_manager = path_manager )
1189
+ if faces is None :
1190
+ faces = torch .zeros (0 , 3 , dtype = torch .int64 )
1191
+
1192
+ textures = None
1193
+ if include_textures and verts_colors is not None :
1194
+ textures = TexturesVertex ([verts_colors .to (device )])
1195
+
1196
+ mesh = Meshes (
1197
+ verts = [verts .to (device )],
1198
+ faces = [faces .to (device )],
1199
+ textures = textures ,
1200
+ )
1187
1201
return mesh
1188
1202
1189
1203
def save (
@@ -1201,14 +1215,30 @@ def save(
1201
1215
# TODO: normals are not saved. We only want to save them if they already exist.
1202
1216
verts = data .verts_list ()[0 ]
1203
1217
faces = data .faces_list ()[0 ]
1204
- save_ply (
1205
- f = path ,
1206
- verts = verts ,
1207
- faces = faces ,
1208
- ascii = binary is False ,
1209
- decimal_places = decimal_places ,
1210
- path_manager = path_manager ,
1211
- )
1218
+
1219
+ if isinstance (data .textures , TexturesVertex ):
1220
+ mesh_verts_colors = data .textures .verts_features_list ()[0 ]
1221
+ n_colors = mesh_verts_colors .shape [1 ]
1222
+ if n_colors == 3 :
1223
+ verts_colors = mesh_verts_colors
1224
+ else :
1225
+ warnings .warn (
1226
+ f"Texture will not be saved as it has { n_colors } colors, not 3."
1227
+ )
1228
+ verts_colors = None
1229
+ else :
1230
+ verts_colors = None
1231
+
1232
+ with _open_file (path , path_manager , "wb" ) as f :
1233
+ _save_ply (
1234
+ f = f ,
1235
+ verts = verts ,
1236
+ faces = faces ,
1237
+ verts_colors = verts_colors ,
1238
+ verts_normals = None ,
1239
+ ascii = binary is False ,
1240
+ decimal_places = decimal_places ,
1241
+ )
1212
1242
return True
1213
1243
1214
1244
@@ -1226,14 +1256,12 @@ def read(
1226
1256
if not endswith (path , self .known_suffixes ):
1227
1257
return None
1228
1258
1229
- verts , faces , features = _load_ply (
1230
- f = path , path_manager = path_manager , return_vertex_colors = True
1231
- )
1259
+ verts , faces , features = _load_ply (f = path , path_manager = path_manager )
1232
1260
verts = verts .to (device )
1233
- if features is None :
1234
- pointcloud = Pointclouds ( points = [ verts ])
1235
- else :
1236
- pointcloud = Pointclouds (points = [verts ], features = [ features . to ( device )] )
1261
+ if features is not None :
1262
+ features = [ features . to ( device )]
1263
+
1264
+ pointcloud = Pointclouds (points = [verts ], features = features )
1237
1265
return pointcloud
1238
1266
1239
1267
def save (
@@ -1249,13 +1277,14 @@ def save(
1249
1277
return False
1250
1278
1251
1279
points = data .points_list ()[0 ]
1252
- features = data .features_list ()[0 ]
1280
+ features = data .features_packed ()
1281
+
1253
1282
with _open_file (path , path_manager , "wb" ) as f :
1254
1283
_save_ply (
1255
1284
f = f ,
1256
1285
verts = points ,
1257
1286
verts_colors = features ,
1258
- verts_normals = torch . FloatTensor ([]) ,
1287
+ verts_normals = None ,
1259
1288
faces = None ,
1260
1289
ascii = binary is False ,
1261
1290
decimal_places = decimal_places ,
0 commit comments