Skip to content

Commit 6ae6ff9

Browse files
include TexturesUV in IO.save_mesh(x.obj)
Summary: Added export of UV textures to IO.save_mesh in Pytorch3d MeshObjFormat now passes verts_uv, faces_uv, and texture_map as input to save_obj TODO: check if TexturesUV.verts_uv_list or TexturesUV.verts_uv_padded() should be passed to save_obj IO.save_mesh(obj_file, meshes, decimal_places=2) should be IO().save_mesh(obj_file, meshes, decimal_places=2) Reviewed By: bottler Differential Revision: D39617441 fbshipit-source-id: 4628b7f26f70e38c65f235852b990c8edb0ded23
1 parent 305cf32 commit 6ae6ff9

File tree

2 files changed

+75
-0
lines changed

2 files changed

+75
-0
lines changed

pytorch3d/io/obj_io.py

+13
Original file line numberDiff line numberDiff line change
@@ -334,12 +334,25 @@ def save(
334334

335335
verts = data.verts_list()[0]
336336
faces = data.faces_list()[0]
337+
338+
verts_uvs: Optional[torch.Tensor] = None
339+
faces_uvs: Optional[torch.Tensor] = None
340+
texture_map: Optional[torch.Tensor] = None
341+
342+
if isinstance(data.textures, TexturesUV):
343+
verts_uvs = data.textures.verts_uvs_padded()[0]
344+
faces_uvs = data.textures.faces_uvs_padded()[0]
345+
texture_map = data.textures.maps_padded()[0]
346+
337347
save_obj(
338348
f=path,
339349
verts=verts,
340350
faces=faces,
341351
decimal_places=decimal_places,
342352
path_manager=path_manager,
353+
verts_uvs=verts_uvs,
354+
faces_uvs=faces_uvs,
355+
texture_map=texture_map,
343356
)
344357
return True
345358

tests/test_io_obj.py

+62
Original file line numberDiff line numberDiff line change
@@ -1050,6 +1050,68 @@ def test_save_obj_with_texture_errors(self):
10501050
texture_map=texture_map[..., 1], # Incorrect shape
10511051
)
10521052

1053+
def test_save_obj_with_texture_IO(self):
1054+
verts = torch.tensor(
1055+
[[0.01, 0.2, 0.301], [0.2, 0.03, 0.408], [0.3, 0.4, 0.05], [0.6, 0.7, 0.8]],
1056+
dtype=torch.float32,
1057+
)
1058+
faces = torch.tensor(
1059+
[[0, 2, 1], [0, 1, 2], [3, 2, 1], [3, 1, 0]], dtype=torch.int64
1060+
)
1061+
verts_uvs = torch.tensor(
1062+
[[0.02, 0.5], [0.3, 0.03], [0.32, 0.12], [0.36, 0.17]],
1063+
dtype=torch.float32,
1064+
)
1065+
faces_uvs = faces
1066+
texture_map = torch.randint(size=(2, 2, 3), high=255) / 255.0
1067+
1068+
with TemporaryDirectory() as temp_dir:
1069+
obj_file = os.path.join(temp_dir, "mesh.obj")
1070+
textures_uv = TexturesUV([texture_map], [faces_uvs], [verts_uvs])
1071+
test_mesh = Meshes(verts=[verts], faces=[faces], textures=textures_uv)
1072+
1073+
IO().save_mesh(data=test_mesh, path=obj_file, decimal_places=2)
1074+
1075+
expected_obj_file = "\n".join(
1076+
[
1077+
"",
1078+
"mtllib mesh.mtl",
1079+
"usemtl mesh",
1080+
"",
1081+
"v 0.01 0.20 0.30",
1082+
"v 0.20 0.03 0.41",
1083+
"v 0.30 0.40 0.05",
1084+
"v 0.60 0.70 0.80",
1085+
"vt 0.02 0.50",
1086+
"vt 0.30 0.03",
1087+
"vt 0.32 0.12",
1088+
"vt 0.36 0.17",
1089+
"f 1/1 3/3 2/2",
1090+
"f 1/1 2/2 3/3",
1091+
"f 4/4 3/3 2/2",
1092+
"f 4/4 2/2 1/1",
1093+
]
1094+
)
1095+
expected_mtl_file = "\n".join(["newmtl mesh", "map_Kd mesh.png", ""])
1096+
1097+
# Check there are only 3 files in the temp dir
1098+
tempfiles = ["mesh.obj", "mesh.png", "mesh.mtl"]
1099+
tempfiles_dir = os.listdir(temp_dir)
1100+
self.assertEqual(Counter(tempfiles), Counter(tempfiles_dir))
1101+
1102+
# Check the obj file is saved correctly
1103+
actual_file = open(obj_file, "r")
1104+
self.assertEqual(actual_file.read(), expected_obj_file)
1105+
1106+
# Check the mtl file is saved correctly
1107+
mtl_file_name = os.path.join(temp_dir, "mesh.mtl")
1108+
mtl_file = open(mtl_file_name, "r")
1109+
self.assertEqual(mtl_file.read(), expected_mtl_file)
1110+
1111+
# Check the texture image file is saved correctly
1112+
texture_image = load_rgb_image("mesh.png", temp_dir)
1113+
self.assertClose(texture_image, texture_map)
1114+
10531115
@staticmethod
10541116
def _bm_save_obj(verts: torch.Tensor, faces: torch.Tensor, decimal_places: int):
10551117
return lambda: save_obj(StringIO(), verts, faces, decimal_places)

0 commit comments

Comments
 (0)