Skip to content

Commit 4426a9d

Browse files
davnov134facebook-github-bot
authored andcommitted
RayBundle visualization
Summary: Extends plotly_vis to visualize `RayBundle`s. Reviewed By: patricklabatut Differential Revision: D29014098 fbshipit-source-id: 4dee426510a1fa53d4afefbe1bcdd003684c9932
1 parent 62ff77b commit 4426a9d

File tree

1 file changed

+214
-25
lines changed

1 file changed

+214
-25
lines changed

pytorch3d/vis/plotly_vis.py

+214-25
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,22 @@
1111
import plotly.graph_objects as go
1212
import torch
1313
from plotly.subplots import make_subplots
14-
from pytorch3d.renderer import TexturesVertex
14+
from pytorch3d.renderer import TexturesVertex, RayBundle, ray_bundle_to_ray_points
1515
from pytorch3d.renderer.camera_utils import camera_to_eye_at_up
1616
from pytorch3d.renderer.cameras import CamerasBase
1717
from pytorch3d.structures import Meshes, Pointclouds, join_meshes_as_scene
1818

1919

20+
Struct = Union[CamerasBase, Meshes, Pointclouds, RayBundle]
21+
22+
23+
def _get_struct_len(struct: Struct): # pragma: no cover
24+
"""
25+
Returns the length (usually corresponds to the batch size) of the input structure.
26+
"""
27+
return len(struct.directions) if isinstance(struct, RayBundle) else len(struct)
28+
29+
2030
def get_camera_wireframe(scale: float = 0.3): # pragma: no cover
2131
"""
2232
Returns a wireframe of a 3D line-plot of a camera symbol.
@@ -55,18 +65,22 @@ class Lighting(NamedTuple): # pragma: no cover
5565

5666

5767
def plot_scene(
58-
plots: Dict[str, Dict[str, Union[Pointclouds, Meshes, CamerasBase]]],
68+
plots: Dict[str, Dict[str, Struct]],
5969
*,
6070
viewpoint_cameras: Optional[CamerasBase] = None,
6171
ncols: int = 1,
6272
camera_scale: float = 0.3,
6373
pointcloud_max_points: int = 20000,
6474
pointcloud_marker_size: int = 1,
75+
raybundle_max_rays: int = 20000,
76+
raybundle_max_points_per_ray: int = 1000,
77+
raybundle_ray_point_marker_size: int = 1,
78+
raybundle_ray_line_width: int = 1,
6579
**kwargs,
6680
): # pragma: no cover
6781
"""
68-
Main function to visualize Meshes, Cameras and Pointclouds.
69-
Plots input Pointclouds, Meshes, and Cameras data into named subplots,
82+
Main function to visualize Cameras, Meshes, Pointclouds, and RayBundle.
83+
Plots input Cameras, Meshes, Pointclouds, and RayBundle data into named subplots,
7084
with named traces based on the dictionary keys. Cameras are
7185
rendered at the camera center location using a wireframe.
7286
@@ -87,6 +101,13 @@ def plot_scene(
87101
pointcloud_max_points is used.
88102
pointcloud_marker_size: the size of the points rendered by plotly
89103
when plotting a pointcloud.
104+
raybundle_max_rays: maximum number of rays of a RayBundle to visualize. Randomly
105+
subsamples without replacement in case the number of rays is bigger than max_rays.
106+
raybundle_max_points_per_ray: the maximum number of points per ray in RayBundle
107+
to visualize. If more are present, a random sample of size
108+
max_points_per_ray is used.
109+
raybundle_ray_point_marker_size: the size of the ray points of a plotted RayBundle
110+
raybundle_ray_line_width: the width of the plotted rays of a RayBundle
90111
**kwargs: Accepts lighting (a Lighting object) and any of the args xaxis,
91112
yaxis and zaxis which Plotly's scene accepts. Accepts axis_args,
92113
which is an AxisArgs object that is applied to all 3 axes.
@@ -186,6 +207,18 @@ def plot_scene(
186207
The above example will render one subplot with the mesh object
187208
and two cameras.
188209
210+
RayBundle visualization is also supproted:
211+
..code-block::python
212+
cameras = PerspectiveCameras(...)
213+
ray_bundle = RayBundle(origins=..., lengths=..., directions=..., xys=...)
214+
fig = plot_scene({
215+
"subplot1_title": {
216+
"ray_bundle_trace_title": ray_bundle,
217+
"cameras_trace_title": cameras,
218+
},
219+
})
220+
fig.show()
221+
189222
For an example of using kwargs, see below:
190223
..code-block::python
191224
mesh = ...
@@ -264,11 +297,22 @@ def plot_scene(
264297
_add_camera_trace(
265298
fig, struct, trace_name, subplot_idx, ncols, camera_scale
266299
)
300+
elif isinstance(struct, RayBundle):
301+
_add_ray_bundle_trace(
302+
fig,
303+
struct,
304+
trace_name,
305+
subplot_idx,
306+
ncols,
307+
raybundle_max_rays,
308+
raybundle_max_points_per_ray,
309+
raybundle_ray_point_marker_size,
310+
raybundle_ray_line_width,
311+
)
267312
else:
268313
raise ValueError(
269-
"struct {} is not a Cameras, Meshes or Pointclouds object".format(
270-
struct
271-
)
314+
"struct {} is not a Cameras, Meshes, Pointclouds,".format(struct)
315+
+ " or RayBundle object."
272316
)
273317

274318
# Ensure update for every subplot.
@@ -329,7 +373,8 @@ def plot_scene(
329373

330374
def plot_batch_individually(
331375
batched_structs: Union[
332-
List[Union[Meshes, Pointclouds, CamerasBase]], Meshes, Pointclouds, CamerasBase
376+
List[Struct],
377+
Struct,
333378
],
334379
*,
335380
viewpoint_cameras: Optional[CamerasBase] = None,
@@ -340,26 +385,27 @@ def plot_batch_individually(
340385
): # pragma: no cover
341386
"""
342387
This is a higher level plotting function than plot_scene, for plotting
343-
Cameras, Meshes and Pointclouds in simple cases. The simplest use is to plot a
344-
single Cameras, Meshes or Pointclouds object, where you just pass it in as a
345-
one element list. This will plot each batch element in a separate subplot.
388+
Cameras, Meshes, Pointclouds, and RayBundle in simple cases. The simplest use
389+
is to plot a single Cameras, Meshes, Pointclouds, or a RayBundle object,
390+
where you just pass it in as a one element list. This will plot each batch
391+
element in a separate subplot.
346392
347-
More generally, you can supply multiple Cameras, Meshes or Pointclouds
393+
More generally, you can supply multiple Cameras, Meshes, Pointclouds, or RayBundle
348394
having the same batch size `n`. In this case, there will be `n` subplots,
349395
each depicting the corresponding batch element of all the inputs.
350396
351-
In addition, you can include Cameras, Meshes and Pointclouds of size 1 in
397+
In addition, you can include Cameras, Meshes, Pointclouds, or RayBundle of size 1 in
352398
the input. These will either be rendered in the first subplot
353399
(if extend_struct is False), or in every subplot.
354400
355401
Args:
356-
batched_structs: a list of Cameras, Meshes and/or Pointclouds to be rendered.
357-
Each structure's corresponding batch element will be plotted in
358-
a single subplot, resulting in n subplots for a batch of size n.
402+
batched_structs: a list of Cameras, Meshes, Pointclouds, and RayBundle
403+
to be rendered. Each structure's corresponding batch element will be
404+
plotted in a single subplot, resulting in n subplots for a batch of size n.
359405
Every struct should either have the same batch size or be of batch size 1.
360406
See extend_struct and the description above for how batch size 1 structs
361-
are handled. Also accepts a single Cameras, Meshes or Pointclouds object,
362-
which will have each individual element plotted in its own subplot.
407+
are handled. Also accepts a single Cameras, Meshes, Pointclouds, and RayBundle
408+
object, which will have each individual element plotted in its own subplot.
363409
viewpoint_cameras: an instance of a Cameras object providing a location
364410
to view the plotly plot from. If the batch size is equal
365411
to the number of subplots, it is a one to one mapping.
@@ -407,13 +453,14 @@ def plot_batch_individually(
407453
return
408454
max_size = 0
409455
if isinstance(batched_structs, list):
410-
max_size = max(len(s) for s in batched_structs)
456+
max_size = max(_get_struct_len(s) for s in batched_structs)
411457
for struct in batched_structs:
412-
if len(struct) not in (1, max_size):
413-
msg = "invalid batch size {} provided: {}".format(len(struct), struct)
458+
struct_len = _get_struct_len(struct)
459+
if struct_len not in (1, max_size):
460+
msg = "invalid batch size {} provided: {}".format(struct_len, struct)
414461
raise ValueError(msg)
415462
else:
416-
max_size = len(batched_structs)
463+
max_size = _get_struct_len(batched_structs)
417464

418465
if max_size == 0:
419466
msg = "No data is provided with at least one element"
@@ -437,7 +484,8 @@ def plot_batch_individually(
437484
if isinstance(batched_structs, list):
438485
for i, batched_struct in enumerate(batched_structs):
439486
# check for whether this struct needs to be extended
440-
if i >= len(batched_struct) and not extend_struct:
487+
batched_struct_len = _get_struct_len(batched_struct)
488+
if i >= batched_struct_len and not extend_struct:
441489
continue
442490
_add_struct_from_batch(
443491
batched_struct, scene_num, subplot_title, scene_dictionary, i + 1
@@ -453,10 +501,10 @@ def plot_batch_individually(
453501

454502

455503
def _add_struct_from_batch(
456-
batched_struct: Union[CamerasBase, Meshes, Pointclouds],
504+
batched_struct: Struct,
457505
scene_num: int,
458506
subplot_title: str,
459-
scene_dictionary: Dict[str, Dict[str, Union[CamerasBase, Meshes, Pointclouds]]],
507+
scene_dictionary: Dict[str, Dict[str, Struct]],
460508
trace_idx: int = 1,
461509
): # pragma: no cover
462510
"""
@@ -492,6 +540,15 @@ def _add_struct_from_batch(
492540
# torch.Tensor, torch.nn.Module]` is not a function.
493541
T = T[t_idx].unsqueeze(0)
494542
struct = CamerasBase(device=batched_struct.device, R=R, T=T)
543+
elif isinstance(batched_struct, RayBundle):
544+
# for RayBundle we treat the 1st dim as the batch index
545+
struct_idx = min(scene_num, len(batched_struct.lengths) - 1)
546+
struct = RayBundle(
547+
**{
548+
attr: getattr(batched_struct, attr)[struct_idx]
549+
for attr in ["origins", "directions", "lengths", "xys"]
550+
}
551+
)
495552
else: # batched meshes and pointclouds are indexable
496553
struct_idx = min(scene_num, len(batched_struct) - 1)
497554
struct = batched_struct[struct_idx]
@@ -702,6 +759,138 @@ def _add_camera_trace(
702759
_update_axes_bounds(verts_center, max_expand, current_layout)
703760

704761

762+
def _add_ray_bundle_trace(
763+
fig: go.Figure,
764+
ray_bundle: RayBundle,
765+
trace_name: str,
766+
subplot_idx: int,
767+
ncols: int,
768+
max_rays: int,
769+
max_points_per_ray: int,
770+
marker_size: int,
771+
line_width: int,
772+
): # pragma: no cover
773+
"""
774+
Adds a trace rendering a RayBundle object to the passed in figure, with
775+
a given name and in a specific subplot.
776+
777+
Args:
778+
fig: plotly figure to add the trace within.
779+
cameras: the Cameras object to render. It can be batched.
780+
trace_name: name to label the trace with.
781+
subplot_idx: identifies the subplot, with 0 being the top left.
782+
ncols: the number of subplots per row.
783+
max_rays: maximum number of plotted rays in total. Randomly subsamples
784+
without replacement in case the number of rays is bigger than max_rays.
785+
max_points_per_ray: maximum number of points plotted per ray.
786+
marker_size: the size of the ray point markers.
787+
line_width: the width of the ray lines.
788+
"""
789+
790+
n_pts_per_ray = ray_bundle.lengths.shape[-1]
791+
n_rays = ray_bundle.lengths.shape[:-1].numel() # pyre-ignore[16]
792+
793+
# flatten all batches of rays into a single big bundle
794+
ray_bundle_flat = RayBundle(
795+
**{
796+
attr: torch.flatten(getattr(ray_bundle, attr), start_dim=0, end_dim=-2)
797+
for attr in ["origins", "directions", "lengths", "xys"]
798+
}
799+
)
800+
801+
# subsample the rays (if needed)
802+
if n_rays > max_rays:
803+
indices_rays = torch.randperm(n_rays)[:max_rays]
804+
ray_bundle_flat = RayBundle(
805+
**{
806+
attr: getattr(ray_bundle_flat, attr)[indices_rays]
807+
for attr in ["origins", "directions", "lengths", "xys"]
808+
}
809+
)
810+
811+
# make ray line endpoints
812+
min_max_ray_depth = torch.stack(
813+
[
814+
ray_bundle_flat.lengths.min(dim=1).values, # pyre-ignore[16]
815+
ray_bundle_flat.lengths.max(dim=1).values,
816+
],
817+
dim=-1,
818+
)
819+
ray_lines_endpoints = ray_bundle_to_ray_points(
820+
ray_bundle_flat._replace(lengths=min_max_ray_depth)
821+
)
822+
823+
# make the ray lines for plotly plotting
824+
nan_tensor = torch.Tensor([[float("NaN")] * 3])
825+
ray_lines = torch.empty(size=(1, 3))
826+
for ray_line in ray_lines_endpoints:
827+
# We combine the ray lines into a single tensor to plot them in a
828+
# single trace. The NaNs are inserted between sets of ray lines
829+
# so that the lines drawn by Plotly are not drawn between
830+
# lines that belong to different rays.
831+
ray_lines = torch.cat((ray_lines, nan_tensor, ray_line))
832+
x, y, z = ray_lines.detach().cpu().numpy().T.astype(float)
833+
row, col = subplot_idx // ncols + 1, subplot_idx % ncols + 1
834+
fig.add_trace(
835+
go.Scatter3d(
836+
x=x,
837+
y=y,
838+
z=z,
839+
marker={"size": 0.1},
840+
line={"width": line_width},
841+
name=trace_name,
842+
),
843+
row=row,
844+
col=col,
845+
)
846+
847+
# subsample the ray points (if needed)
848+
if n_pts_per_ray > max_points_per_ray:
849+
indices_ray_pts = torch.cat(
850+
[
851+
torch.randperm(n_pts_per_ray)[:max_points_per_ray] + ri * n_pts_per_ray
852+
for ri in range(ray_bundle_flat.lengths.shape[0])
853+
]
854+
)
855+
ray_bundle_flat = ray_bundle_flat._replace(
856+
lengths=ray_bundle_flat.lengths.reshape(-1)[indices_ray_pts].reshape(
857+
ray_bundle_flat.lengths.shape[0], -1
858+
)
859+
)
860+
861+
# plot the ray points
862+
ray_points = (
863+
ray_bundle_to_ray_points(ray_bundle_flat)
864+
.view(-1, 3)
865+
.detach()
866+
.cpu()
867+
.numpy()
868+
.astype(float)
869+
)
870+
fig.add_trace(
871+
go.Scatter3d(
872+
x=ray_points[:, 0],
873+
y=ray_points[:, 1],
874+
z=ray_points[:, 2],
875+
mode="markers",
876+
name=trace_name + "_points",
877+
marker={"size": marker_size},
878+
),
879+
row=row,
880+
col=col,
881+
)
882+
883+
# Access the current subplot's scene configuration
884+
plot_scene = "scene" + str(subplot_idx + 1)
885+
current_layout = fig["layout"][plot_scene]
886+
887+
# update the bounds of the axes for the current trace
888+
all_ray_points = ray_bundle_to_ray_points(ray_bundle).view(-1, 3)
889+
ray_points_center = all_ray_points.mean(dim=0)
890+
max_expand = (all_ray_points.max(0)[0] - all_ray_points.min(0)[0]).max().item()
891+
_update_axes_bounds(ray_points_center, float(max_expand), current_layout)
892+
893+
705894
def _gen_fig_with_subplots(
706895
batch_size: int, ncols: int, subplot_titles: List[str]
707896
): # pragma: no cover

0 commit comments

Comments
 (0)