Skip to content

Commit 6e25fe8

Browse files
bottler authored and facebook-github-bot committed
visualize_reconstruction fixes
Summary: Various fixes to get visualize_reconstruction running, and an interactive test for it. Reviewed By: kjchalup Differential Revision: D39286691 fbshipit-source-id: 88735034cc01736b24735bcb024577e6ab7ed336
1 parent 34ad77b commit 6e25fe8

File tree

8 files changed

+125
-58
lines changed

8 files changed

+125
-58
lines changed

projects/implicitron_trainer/README.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ If you have a custom `experiment.py` script (as in the Option 2 above), replace
6666
To run training, pass a yaml config file, followed by a list of overridden arguments.
6767
For example, to train NeRF on the first skateboard sequence from CO3D dataset, you can run:
6868
```shell
69-
dataset_args=data_source_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
69+
dataset_args=data_source_ImplicitronDataSource_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
7070
pytorch3d_implicitron_runner --config-path ./configs/ --config-name repro_singleseq_nerf \
7171
$dataset_args.dataset_root=<DATASET_ROOT> $dataset_args.category='skateboard' \
7272
$dataset_args.test_restrict_sequence_id=0 test_when_finished=True exp_dir=<CHECKPOINT_DIR>
@@ -87,7 +87,7 @@ To run evaluation on the latest checkpoint after (or during) training, simply ad
8787

8888
E.g. for executing the evaluation on the NeRF skateboard sequence, you can run:
8989
```shell
90-
dataset_args=data_source_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
90+
dataset_args=data_source_ImplicitronDataSource_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
9191
pytorch3d_implicitron_runner --config-path ./configs/ --config-name repro_singleseq_nerf \
9292
$dataset_args.dataset_root=<CO3D_DATASET_ROOT> $dataset_args.category='skateboard' \
9393
$dataset_args.test_restrict_sequence_id=0 exp_dir=<CHECKPOINT_DIR> eval_only=True

projects/implicitron_trainer/tests/test_experiment.py

+1-10
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,7 @@
1313
from omegaconf import OmegaConf
1414

1515
from .. import experiment
16-
from .utils import intercept_logs
17-
18-
19-
def interactive_testing_requested() -> bool:
20-
"""
21-
Certain tests are only useful when run interactively, and so are not regularly run.
22-
These are activated by this function returning True, which the user requests by
23-
setting the environment variable `PYTORCH3D_INTERACTIVE_TESTING` to 1.
24-
"""
25-
return os.environ.get("PYTORCH3D_INTERACTIVE_TESTING", "") == "1"
16+
from .utils import interactive_testing_requested, intercept_logs
2617

2718

2819
internal = os.environ.get("FB_TEST", False)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import os
8+
import unittest
9+
10+
from .. import visualize_reconstruction
11+
from .utils import interactive_testing_requested
12+
13+
internal = os.environ.get("FB_TEST", False)
14+
15+
16+
class TestVisualize(unittest.TestCase):
17+
def test_from_defaults(self):
18+
if not interactive_testing_requested():
19+
return
20+
checkpoint_dir = os.environ["exp_dir"]
21+
argv = [
22+
f"exp_dir={checkpoint_dir}",
23+
"n_eval_cameras=40",
24+
"render_size=[64,64]",
25+
"video_size=[256,256]",
26+
]
27+
visualize_reconstruction.main(argv)

projects/implicitron_trainer/tests/utils.py

+10
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import contextlib
88
import logging
9+
import os
910
import re
1011

1112

@@ -28,3 +29,12 @@ def filter(self, record):
2829
yield intercepted_messages
2930
finally:
3031
logger.removeFilter(interceptor)
32+
33+
34+
def interactive_testing_requested() -> bool:
35+
"""
36+
Certain tests are only useful when run interactively, and so are not regularly run.
37+
These are activated by this function returning True, which the user requests by
38+
setting the environment variable `PYTORCH3D_INTERACTIVE_TESTING` to 1.
39+
"""
40+
return os.environ.get("PYTORCH3D_INTERACTIVE_TESTING", "") == "1"

projects/implicitron_trainer/visualize_reconstruction.py

+36-22
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55
# This source code is licensed under the BSD-style license found in the
66
# LICENSE file in the root directory of this source tree.
77

8-
"""Script to visualize a previously trained model. Example call:
8+
"""
9+
Script to visualize a previously trained model. Example call:
910
10-
projects/implicitron_trainer/visualize_reconstruction.py
11-
exp_dir='./exps/checkpoint_dir' visdom_show_preds=True visdom_port=8097
11+
pytorch3d_implicitron_visualizer \
12+
exp_dir='./exps/checkpoint_dir' visdom_show_preds=True visdom_port=8097 \
1213
n_eval_cameras=40 render_size="[64,64]" video_size="[256,256]"
1314
"""
1415

@@ -18,9 +19,9 @@
1819

1920
import numpy as np
2021
import torch
21-
from omegaconf import OmegaConf
22-
from pytorch3d.implicitron.models.visualization import render_flyaround
23-
from pytorch3d.implicitron.tools.configurable import get_default_args
22+
from omegaconf import DictConfig, OmegaConf
23+
from pytorch3d.implicitron.models.visualization.render_flyaround import render_flyaround
24+
from pytorch3d.implicitron.tools.config import enable_get_default_args, get_default_args
2425

2526
from .experiment import Experiment
2627

@@ -38,7 +39,7 @@ def visualize_reconstruction(
3839
visdom_server: str = "http://127.0.0.1",
3940
visdom_port: int = 8097,
4041
visdom_env: Optional[str] = None,
41-
):
42+
) -> None:
4243
"""
4344
Given an `exp_dir` containing a trained Implicitron model, generates videos consisting
4445
of renderes of sequences from the dataset used to train and evaluate the trained
@@ -76,22 +77,27 @@ def visualize_reconstruction(
7677
config = _get_config_from_experiment_directory(exp_dir)
7778
config.exp_dir = exp_dir
7879
# important so that the CO3D dataset gets loaded in full
79-
dataset_args = (
80-
config.data_source_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
81-
)
82-
dataset_args.test_on_train = False
80+
data_source_args = config.data_source_ImplicitronDataSource_args
81+
if "dataset_map_provider_JsonIndexDatasetMapProvider_args" in data_source_args:
82+
dataset_args = (
83+
data_source_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
84+
)
85+
dataset_args.test_on_train = False
86+
if restrict_sequence_name is not None:
87+
dataset_args.restrict_sequence_name = restrict_sequence_name
88+
8389
# Set the rendering image size
8490
model_factory_args = config.model_factory_ImplicitronModelFactory_args
91+
model_factory_args.force_resume = True
8592
model_args = model_factory_args.model_GenericModel_args
8693
model_args.render_image_width = render_size[0]
8794
model_args.render_image_height = render_size[1]
88-
if restrict_sequence_name is not None:
89-
dataset_args.restrict_sequence_name = restrict_sequence_name
9095

9196
# Load the previously trained model
92-
experiment = Experiment(config)
93-
model = experiment.model_factory(force_resume=True)
94-
model.cuda()
97+
experiment = Experiment(**config)
98+
model = experiment.model_factory(exp_dir=exp_dir)
99+
device = torch.device("cuda")
100+
model.to(device)
95101
model.eval()
96102

97103
# Setup the dataset
@@ -101,6 +107,11 @@ def visualize_reconstruction(
101107
if dataset is None:
102108
raise ValueError(f"{split} dataset not provided")
103109

110+
if visdom_env is None:
111+
visdom_env = (
112+
"visualizer_" + config.training_loop_ImplicitronTrainingLoop_args.visdom_env
113+
)
114+
104115
# iterate over the sequences in the dataset
105116
for sequence_name in dataset.sequence_names():
106117
with torch.no_grad():
@@ -114,23 +125,26 @@ def visualize_reconstruction(
114125
n_flyaround_poses=n_eval_cameras,
115126
visdom_server=visdom_server,
116127
visdom_port=visdom_port,
117-
visdom_environment=f"visualizer_{config.visdom_env}"
118-
if visdom_env is None
119-
else visdom_env,
128+
visdom_environment=visdom_env,
120129
video_resize=video_size,
130+
device=device,
121131
)
122132

123133

124-
def _get_config_from_experiment_directory(experiment_directory):
134+
enable_get_default_args(visualize_reconstruction)
135+
136+
137+
def _get_config_from_experiment_directory(experiment_directory) -> DictConfig:
125138
cfg_file = os.path.join(experiment_directory, "expconfig.yaml")
126139
config = OmegaConf.load(cfg_file)
140+
# pyre-ignore[7]
127141
return config
128142

129143

130-
def main(argv):
144+
def main(argv) -> None:
131145
# automatically parses arguments of visualize_reconstruction
132146
cfg = OmegaConf.create(get_default_args(visualize_reconstruction))
133-
cfg.update(OmegaConf.from_cli())
147+
cfg.update(OmegaConf.from_cli(argv))
134148
with torch.no_grad():
135149
visualize_reconstruction(**cfg)
136150

pytorch3d/implicitron/dataset/single_sequence_dataset.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
# provide data for a single scene.
1010

1111
from dataclasses import field
12-
from typing import Iterable, List, Optional
12+
from typing import Iterable, Iterator, List, Optional, Tuple
1313

1414
import numpy as np
1515
import torch
@@ -46,6 +46,12 @@ def sequence_names(self) -> Iterable[str]:
4646
def __len__(self) -> int:
4747
return len(self.poses)
4848

49+
def sequence_frames_in_order(
50+
self, seq_name: str
51+
) -> Iterator[Tuple[float, int, int]]:
52+
for i in range(len(self)):
53+
yield (0.0, i, i)
54+
4955
def __getitem__(self, index) -> FrameData:
5056
if index >= len(self):
5157
raise IndexError(f"index {index} out of range {len(self)}")

pytorch3d/implicitron/models/visualization/render_flyaround.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def render_flyaround(
6161
"depths_render",
6262
"_all_source_images",
6363
),
64-
):
64+
) -> None:
6565
"""
6666
Uses `model` to generate a video consisting of renders of a scene imaged from
6767
a camera flying around the scene. The scene is specified with the `dataset` object and
@@ -133,6 +133,7 @@ def render_flyaround(
133133
seq_idx = list(dataset.sequence_indices_in_order(sequence_name))
134134
train_data = _load_whole_dataset(dataset, seq_idx, num_workers=num_workers)
135135
assert all(train_data.sequence_name[0] == sn for sn in train_data.sequence_name)
136+
# pyre-ignore[6]
136137
sequence_set_name = "train" if is_train_frame(train_data.frame_type)[0] else "test"
137138
logger.info(f"Sequence set = {sequence_set_name}.")
138139
train_cameras = train_data.camera
@@ -209,7 +210,7 @@ def render_flyaround(
209210

210211
def _load_whole_dataset(
211212
dataset: torch.utils.data.Dataset, idx: Sequence[int], num_workers: int = 10
212-
):
213+
) -> FrameData:
213214
load_all_dataloader = torch.utils.data.DataLoader(
214215
torch.utils.data.Subset(dataset, idx),
215216
batch_size=len(idx),
@@ -220,7 +221,7 @@ def _load_whole_dataset(
220221
return next(iter(load_all_dataloader))
221222

222223

223-
def _images_from_preds(preds: Dict[str, Any]):
224+
def _images_from_preds(preds: Dict[str, Any]) -> Dict[str, torch.Tensor]:
224225
imout = {}
225226
for k in (
226227
"image_rgb",
@@ -253,7 +254,7 @@ def _images_from_preds(preds: Dict[str, Any]):
253254
return imout
254255

255256

256-
def _stack_images(ims: torch.Tensor, size: Optional[Tuple[int, int]]):
257+
def _stack_images(ims: torch.Tensor, size: Optional[Tuple[int, int]]) -> torch.Tensor:
257258
ba = ims.shape[0]
258259
H = int(np.ceil(np.sqrt(ba)))
259260
W = H
@@ -281,7 +282,7 @@ def _show_predictions(
281282
),
282283
n_samples=10,
283284
one_image_width=200,
284-
):
285+
) -> None:
285286
"""Given a list of predictions visualize them into a single image using visdom."""
286287
assert isinstance(preds, list)
287288

@@ -329,7 +330,7 @@ def _generate_prediction_videos(
329330
video_path: str = "/tmp/video",
330331
video_frames_dir: Optional[str] = None,
331332
resize: Optional[Tuple[int, int]] = None,
332-
):
333+
) -> None:
333334
"""Given a list of predictions create and visualize rotating videos of the
334335
objects using visdom.
335336
"""
@@ -359,7 +360,7 @@ def _generate_prediction_videos(
359360
)
360361

361362
for k in predicted_keys:
362-
vws[k].get_video(quiet=True)
363+
vws[k].get_video()
363364
logger.info(f"Generated {vws[k].out_path}.")
364365
if viz is not None:
365366
viz.video(

pytorch3d/implicitron/tools/video_writer.py

+34-16
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import os
88
import shutil
9+
import subprocess
910
import tempfile
1011
import warnings
1112
from typing import Optional, Tuple, Union
@@ -15,6 +16,7 @@
1516
import numpy as np
1617
from PIL import Image
1718

19+
_DEFAULT_FFMPEG = os.environ.get("FFMPEG", "ffmpeg")
1820

1921
matplotlib.use("Agg")
2022

@@ -27,13 +29,13 @@ class VideoWriter:
2729
def __init__(
2830
self,
2931
cache_dir: Optional[str] = None,
30-
ffmpeg_bin: str = "ffmpeg",
32+
ffmpeg_bin: str = _DEFAULT_FFMPEG,
3133
out_path: str = "/tmp/video.mp4",
3234
fps: int = 20,
3335
output_format: str = "visdom",
3436
rmdir_allowed: bool = False,
3537
**kwargs,
36-
):
38+
) -> None:
3739
"""
3840
Args:
3941
cache_dir: A directory for storing the video frames. If `None`,
@@ -74,7 +76,7 @@ def write_frame(
7476
self,
7577
frame: Union[matplotlib.figure.Figure, np.ndarray, Image.Image, str],
7678
resize: Optional[Union[float, Tuple[int, int]]] = None,
77-
):
79+
) -> None:
7880
"""
7981
Write a frame to the video.
8082
@@ -114,7 +116,7 @@ def write_frame(
114116
self.frames.append(outfile)
115117
self.frame_num += 1
116118

117-
def get_video(self, quiet: bool = True):
119+
def get_video(self) -> str:
118120
"""
119121
Generate the video from the written frames.
120122
@@ -127,23 +129,39 @@ def get_video(self, quiet: bool = True):
127129

128130
regexp = os.path.join(self.cache_dir, self.regexp)
129131

130-
if self.output_format == "visdom": # works for ppt too
131-
ffmcmd_ = (
132-
"%s -r %d -i %s -vcodec h264 -f mp4 \
133-
-y -crf 18 -b 2000k -pix_fmt yuv420p '%s'"
134-
% (self.ffmpeg_bin, self.fps, regexp, self.out_path)
132+
if shutil.which(self.ffmpeg_bin) is None:
133+
raise ValueError(
134+
f"Cannot find ffmpeg as `{self.ffmpeg_bin}`. "
135+
+ "Please set FFMPEG in the environment or ffmpeg_bin on this class."
135136
)
136-
else:
137-
raise ValueError("no such output type %s" % str(self.output_format))
138137

139-
if quiet:
140-
ffmcmd_ += " > /dev/null 2>&1"
138+
if self.output_format == "visdom": # works for ppt too
139+
args = [
140+
self.ffmpeg_bin,
141+
"-r",
142+
str(self.fps),
143+
"-i",
144+
regexp,
145+
"-vcodec",
146+
"h264",
147+
"-f",
148+
"mp4",
149+
"-y",
150+
"-crf",
151+
"18",
152+
"-b",
153+
"2000k",
154+
"-pix_fmt",
155+
"yuv420p",
156+
self.out_path,
157+
]
158+
159+
subprocess.check_call(args)
141160
else:
142-
print(ffmcmd_)
143-
os.system(ffmcmd_)
161+
raise ValueError("no such output type %s" % str(self.output_format))
144162

145163
return self.out_path
146164

147-
def __del__(self):
165+
def __del__(self) -> None:
148166
if self.tmp_dir is not None:
149167
self.tmp_dir.cleanup()

0 commit comments

Comments (0)