5
5
# This source code is licensed under the BSD-style license found in the
6
6
# LICENSE file in the root directory of this source tree.
7
7
8
- """Script to visualize a previously trained model. Example call:
8
+ """
9
+ Script to visualize a previously trained model. Example call:
9
10
10
- projects/implicitron_trainer/visualize_reconstruction.py
11
- exp_dir='./exps/checkpoint_dir' visdom_show_preds=True visdom_port=8097
11
+ pytorch3d_implicitron_visualizer \
12
+ exp_dir='./exps/checkpoint_dir' visdom_show_preds=True visdom_port=8097 \
12
13
n_eval_cameras=40 render_size="[64,64]" video_size="[256,256]"
13
14
"""
14
15
18
19
19
20
import numpy as np
20
21
import torch
21
- from omegaconf import OmegaConf
22
- from pytorch3d .implicitron .models .visualization import render_flyaround
23
- from pytorch3d .implicitron .tools .configurable import get_default_args
22
+ from omegaconf import DictConfig , OmegaConf
23
+ from pytorch3d .implicitron .models .visualization . render_flyaround import render_flyaround
24
+ from pytorch3d .implicitron .tools .config import enable_get_default_args , get_default_args
24
25
25
26
from .experiment import Experiment
26
27
@@ -38,7 +39,7 @@ def visualize_reconstruction(
38
39
visdom_server : str = "http://127.0.0.1" ,
39
40
visdom_port : int = 8097 ,
40
41
visdom_env : Optional [str ] = None ,
41
- ):
42
+ ) -> None :
42
43
"""
43
44
Given an `exp_dir` containing a trained Implicitron model, generates videos consisting
44
45
of renderes of sequences from the dataset used to train and evaluate the trained
@@ -76,22 +77,27 @@ def visualize_reconstruction(
76
77
config = _get_config_from_experiment_directory (exp_dir )
77
78
config .exp_dir = exp_dir
78
79
# important so that the CO3D dataset gets loaded in full
79
- dataset_args = (
80
- config .data_source_args .dataset_map_provider_JsonIndexDatasetMapProvider_args
81
- )
82
- dataset_args .test_on_train = False
80
+ data_source_args = config .data_source_ImplicitronDataSource_args
81
+ if "dataset_map_provider_JsonIndexDatasetMapProvider_args" in data_source_args :
82
+ dataset_args = (
83
+ data_source_args .dataset_map_provider_JsonIndexDatasetMapProvider_args
84
+ )
85
+ dataset_args .test_on_train = False
86
+ if restrict_sequence_name is not None :
87
+ dataset_args .restrict_sequence_name = restrict_sequence_name
88
+
83
89
# Set the rendering image size
84
90
model_factory_args = config .model_factory_ImplicitronModelFactory_args
91
+ model_factory_args .force_resume = True
85
92
model_args = model_factory_args .model_GenericModel_args
86
93
model_args .render_image_width = render_size [0 ]
87
94
model_args .render_image_height = render_size [1 ]
88
- if restrict_sequence_name is not None :
89
- dataset_args .restrict_sequence_name = restrict_sequence_name
90
95
91
96
# Load the previously trained model
92
- experiment = Experiment (config )
93
- model = experiment .model_factory (force_resume = True )
94
- model .cuda ()
97
+ experiment = Experiment (** config )
98
+ model = experiment .model_factory (exp_dir = exp_dir )
99
+ device = torch .device ("cuda" )
100
+ model .to (device )
95
101
model .eval ()
96
102
97
103
# Setup the dataset
@@ -101,6 +107,11 @@ def visualize_reconstruction(
101
107
if dataset is None :
102
108
raise ValueError (f"{ split } dataset not provided" )
103
109
110
+ if visdom_env is None :
111
+ visdom_env = (
112
+ "visualizer_" + config .training_loop_ImplicitronTrainingLoop_args .visdom_env
113
+ )
114
+
104
115
# iterate over the sequences in the dataset
105
116
for sequence_name in dataset .sequence_names ():
106
117
with torch .no_grad ():
@@ -114,23 +125,26 @@ def visualize_reconstruction(
114
125
n_flyaround_poses = n_eval_cameras ,
115
126
visdom_server = visdom_server ,
116
127
visdom_port = visdom_port ,
117
- visdom_environment = f"visualizer_{ config .visdom_env } "
118
- if visdom_env is None
119
- else visdom_env ,
128
+ visdom_environment = visdom_env ,
120
129
video_resize = video_size ,
130
+ device = device ,
121
131
)
122
132
123
133
124
- def _get_config_from_experiment_directory (experiment_directory ):
134
+ enable_get_default_args (visualize_reconstruction )
135
+
136
+
137
+ def _get_config_from_experiment_directory (experiment_directory ) -> DictConfig :
125
138
cfg_file = os .path .join (experiment_directory , "expconfig.yaml" )
126
139
config = OmegaConf .load (cfg_file )
140
+ # pyre-ignore[7]
127
141
return config
128
142
129
143
130
- def main (argv ):
144
+ def main (argv ) -> None :
131
145
# automatically parses arguments of visualize_reconstruction
132
146
cfg = OmegaConf .create (get_default_args (visualize_reconstruction ))
133
- cfg .update (OmegaConf .from_cli ())
147
+ cfg .update (OmegaConf .from_cli (argv ))
134
148
with torch .no_grad ():
135
149
visualize_reconstruction (** cfg )
136
150
0 commit comments