Skip to content

Commit fd992b1

Browse files
committed
allow saving vertex normal in save_obj
1 parent 1af6bf4 commit fd992b1

File tree

2 files changed

+179
-1
lines changed

2 files changed

+179
-1
lines changed

pytorch3d/io/obj_io.py

+55-1
Original file line numberDiff line numberDiff line change
@@ -684,6 +684,8 @@ def save_obj(
684684
decimal_places: Optional[int] = None,
685685
path_manager: Optional[PathManager] = None,
686686
*,
687+
verts_normals: Optional[torch.Tensor] = None,
688+
faces_normals: Optional[torch.Tensor] = None,
687689
verts_uvs: Optional[torch.Tensor] = None,
688690
faces_uvs: Optional[torch.Tensor] = None,
689691
texture_map: Optional[torch.Tensor] = None,
@@ -698,6 +700,9 @@ def save_obj(
698700
decimal_places: Number of decimal places for saving.
699701
path_manager: Optional PathManager for interpreting f if
700702
it is a str.
703+
verts_normals: FloatTensor of shape (V, 3) giving the normal per vertex.
704+
faces_normals: LongTensor of shape (F, 3) giving the index into verts_normals
705+
for each vertex in the face.
701706
verts_uvs: FloatTensor of shape (V, 2) giving the uv coordinate per vertex.
702707
faces_uvs: LongTensor of shape (F, 3) giving the index into verts_uvs for
703708
each vertex in the face.
@@ -712,6 +717,14 @@ def save_obj(
712717
if len(faces) and (faces.dim() != 2 or faces.size(1) != 3):
713718
message = "'faces' should either be empty or of shape (num_faces, 3)."
714719
raise ValueError(message)
720+
721+
if faces_normals is not None and (faces_normals.dim() != 2 or faces_normals.size(1) != 3):
722+
message = "'faces_normals' should either be empty or of shape (num_faces, 3)."
723+
raise ValueError(message)
724+
725+
if verts_normals is not None and (verts_normals.dim() != 2 or verts_normals.size(1) != 3):
726+
message = "'verts_normals' should either be empty or of shape (num_verts, 3)."
727+
raise ValueError(message)
715728

716729
if faces_uvs is not None and (faces_uvs.dim() != 2 or faces_uvs.size(1) != 3):
717730
message = "'faces_uvs' should either be empty or of shape (num_faces, 3)."
@@ -728,6 +741,7 @@ def save_obj(
728741
if path_manager is None:
729742
path_manager = PathManager()
730743

744+
save_normals = all([n is not None for n in [verts_normals, faces_normals]])
731745
save_texture = all([t is not None for t in [faces_uvs, verts_uvs, texture_map]])
732746
output_path = Path(f)
733747

@@ -742,9 +756,12 @@ def save_obj(
742756
verts,
743757
faces,
744758
decimal_places,
759+
verts_normals=verts_normals,
760+
faces_normals=faces_normals,
745761
verts_uvs=verts_uvs,
746762
faces_uvs=faces_uvs,
747763
save_texture=save_texture,
764+
save_normals=save_normals,
748765
)
749766

750767
# Save the .mtl and .png files associated with the texture
@@ -777,9 +794,12 @@ def _save(
777794
faces,
778795
decimal_places: Optional[int] = None,
779796
*,
797+
verts_normals: Optional[torch.Tensor] = None,
798+
faces_normals: Optional[torch.Tensor] = None,
780799
verts_uvs: Optional[torch.Tensor] = None,
781800
faces_uvs: Optional[torch.Tensor] = None,
782801
save_texture: bool = False,
802+
save_normals: bool = False,
783803
) -> None:
784804

785805
if len(verts) and (verts.dim() != 2 or verts.size(1) != 3):
@@ -809,6 +829,25 @@ def _save(
809829
vert = [float_str % verts[i, j] for j in range(D)]
810830
lines += "v %s\n" % " ".join(vert)
811831

832+
if save_normals:
833+
if faces_normals is not None and (faces_normals.dim() != 2 or faces_normals.size(1) != 3):
834+
message = "'faces_normals' should either be empty or of shape (num_faces, 3)."
835+
raise ValueError(message)
836+
837+
if verts_normals is not None and (verts_normals.dim() != 2 or verts_normals.size(1) != 3):
838+
message = "'verts_normals' should either be empty or of shape (num_verts, 3)."
839+
raise ValueError(message)
840+
841+
# pyre-fixme[16] # undefined attribute cpu
842+
verts_normals, faces_normals = verts_normals.cpu(), faces_normals.cpu()
843+
844+
# Save verts normals after verts
845+
if len(verts_normals):
846+
V, D = verts_normals.shape
847+
for i in range(V):
848+
normal = [float_str % verts_normals[i, j] for j in range(D)]
849+
lines += "vn %s\n" % " ".join(normal)
850+
812851
if save_texture:
813852
if faces_uvs is not None and (faces_uvs.dim() != 2 or faces_uvs.size(1) != 3):
814853
message = "'faces_uvs' should either be empty or of shape (num_faces, 3)."
@@ -834,7 +873,22 @@ def _save(
834873
if len(faces):
835874
F, P = faces.shape
836875
for i in range(F):
837-
if save_texture:
876+
if save_texture and save_normals:
877+
# Format faces as {verts_idx}/{verts_uvs_idx}/{verts_normals_idx}
878+
face = [
879+
"%d/%d/%d" % (
880+
faces[i, j] + 1,
881+
faces_uvs[i, j] + 1,
882+
faces_normals[i, j] + 1,
883+
)
884+
for j in range(P)
885+
]
886+
elif save_normals:
887+
# Format faces as {verts_idx}//{verts_normals_idx}
888+
face = [
889+
"%d//%d" % (faces[i, j] + 1, faces_normals[i, j] + 1) for j in range(P)
890+
]
891+
elif save_texture:
838892
# Format faces as {verts_idx}/{verts_uvs_idx}
839893
face = [
840894
"%d/%d" % (faces[i, j] + 1, faces_uvs[i, j] + 1) for j in range(P)

tests/test_io_obj.py

+124
Original file line numberDiff line numberDiff line change
@@ -895,6 +895,52 @@ def check_item(x, y):
895895
with self.assertRaisesRegex(ValueError, "same type of texture"):
896896
join_meshes_as_batch([mesh_atlas, mesh_rgb, mesh_atlas])
897897

898+
def test_save_obj_with_normal(self):
899+
verts = torch.tensor(
900+
[[0.01, 0.2, 0.301], [0.2, 0.03, 0.408], [0.3, 0.4, 0.05], [0.6, 0.7, 0.8]],
901+
dtype=torch.float32,
902+
)
903+
faces = torch.tensor(
904+
[[0, 2, 1], [0, 1, 2], [3, 2, 1], [3, 1, 0]], dtype=torch.int64
905+
)
906+
verts_normals = torch.tensor(
907+
[[0.02, 0.5, 0.73], [0.3, 0.03, 0.361], [0.32, 0.12, 0.47], [0.36, 0.17, 0.9]],
908+
dtype=torch.float32,
909+
)
910+
faces_normals = faces
911+
912+
with TemporaryDirectory() as temp_dir:
913+
obj_file = os.path.join(temp_dir, "mesh.obj")
914+
save_obj(
915+
obj_file,
916+
verts,
917+
faces,
918+
decimal_places=2,
919+
verts_normals=verts_normals,
920+
faces_normals=faces_normals,
921+
)
922+
923+
expected_obj_file = "\n".join(
924+
[
925+
"v 0.01 0.20 0.30",
926+
"v 0.20 0.03 0.41",
927+
"v 0.30 0.40 0.05",
928+
"v 0.60 0.70 0.80",
929+
"vn 0.02 0.50 0.73",
930+
"vn 0.30 0.03 0.36",
931+
"vn 0.32 0.12 0.47",
932+
"vn 0.36 0.17 0.90",
933+
"f 1//1 3//3 2//2",
934+
"f 1//1 2//2 3//3",
935+
"f 4//4 3//3 2//2",
936+
"f 4//4 2//2 1//1",
937+
]
938+
)
939+
940+
# Check the obj file is saved correctly
941+
actual_file = open(obj_file, "r")
942+
self.assertEqual(actual_file.read(), expected_obj_file)
943+
898944
def test_save_obj_with_texture(self):
899945
verts = torch.tensor(
900946
[[0.01, 0.2, 0.301], [0.2, 0.03, 0.408], [0.3, 0.4, 0.05], [0.6, 0.7, 0.8]],
@@ -962,6 +1008,84 @@ def test_save_obj_with_texture(self):
9621008
texture_image = load_rgb_image("mesh.png", temp_dir)
9631009
self.assertClose(texture_image, texture_map)
9641010

1011+
def test_save_obj_with_normal_and_texture(self):
1012+
verts = torch.tensor(
1013+
[[0.01, 0.2, 0.301], [0.2, 0.03, 0.408], [0.3, 0.4, 0.05], [0.6, 0.7, 0.8]],
1014+
dtype=torch.float32,
1015+
)
1016+
faces = torch.tensor(
1017+
[[0, 2, 1], [0, 1, 2], [3, 2, 1], [3, 1, 0]], dtype=torch.int64
1018+
)
1019+
verts_normals = torch.tensor(
1020+
[[0.02, 0.5, 0.73], [0.3, 0.03, 0.361], [0.32, 0.12, 0.47], [0.36, 0.17, 0.9]],
1021+
dtype=torch.float32,
1022+
)
1023+
faces_normals = faces
1024+
verts_uvs = torch.tensor(
1025+
[[0.02, 0.5], [0.3, 0.03], [0.32, 0.12], [0.36, 0.17]],
1026+
dtype=torch.float32,
1027+
)
1028+
faces_uvs = faces
1029+
texture_map = torch.randint(size=(2, 2, 3), high=255) / 255.0
1030+
1031+
with TemporaryDirectory() as temp_dir:
1032+
obj_file = os.path.join(temp_dir, "mesh.obj")
1033+
save_obj(
1034+
obj_file,
1035+
verts,
1036+
faces,
1037+
decimal_places=2,
1038+
verts_normals=verts_normals,
1039+
faces_normals=faces_normals,
1040+
verts_uvs=verts_uvs,
1041+
faces_uvs=faces_uvs,
1042+
texture_map=texture_map,
1043+
)
1044+
1045+
expected_obj_file = "\n".join(
1046+
[
1047+
"",
1048+
"mtllib mesh.mtl",
1049+
"usemtl mesh",
1050+
"",
1051+
"v 0.01 0.20 0.30",
1052+
"v 0.20 0.03 0.41",
1053+
"v 0.30 0.40 0.05",
1054+
"v 0.60 0.70 0.80",
1055+
"vn 0.02 0.50 0.73",
1056+
"vn 0.30 0.03 0.36",
1057+
"vn 0.32 0.12 0.47",
1058+
"vn 0.36 0.17 0.90",
1059+
"vt 0.02 0.50",
1060+
"vt 0.30 0.03",
1061+
"vt 0.32 0.12",
1062+
"vt 0.36 0.17",
1063+
"f 1/1/1 3/3/3 2/2/2",
1064+
"f 1/1/1 2/2/2 3/3/3",
1065+
"f 4/4/4 3/3/3 2/2/2",
1066+
"f 4/4/4 2/2/2 1/1/1",
1067+
]
1068+
)
1069+
expected_mtl_file = "\n".join(["newmtl mesh", "map_Kd mesh.png", ""])
1070+
1071+
# Check there are only 3 files in the temp dir
1072+
tempfiles = ["mesh.obj", "mesh.png", "mesh.mtl"]
1073+
tempfiles_dir = os.listdir(temp_dir)
1074+
self.assertEqual(Counter(tempfiles), Counter(tempfiles_dir))
1075+
1076+
# Check the obj file is saved correctly
1077+
actual_file = open(obj_file, "r")
1078+
self.assertEqual(actual_file.read(), expected_obj_file)
1079+
1080+
# Check the mtl file is saved correctly
1081+
mtl_file_name = os.path.join(temp_dir, "mesh.mtl")
1082+
mtl_file = open(mtl_file_name, "r")
1083+
self.assertEqual(mtl_file.read(), expected_mtl_file)
1084+
1085+
# Check the texture image file is saved correctly
1086+
texture_image = load_rgb_image("mesh.png", temp_dir)
1087+
self.assertClose(texture_image, texture_map)
1088+
9651089
def test_save_obj_with_texture_errors(self):
9661090
verts = torch.tensor(
9671091
[[0.01, 0.2, 0.301], [0.2, 0.03, 0.408], [0.3, 0.4, 0.05], [0.6, 0.7, 0.8]],

0 commit comments

Comments
 (0)