11
11
import plotly .graph_objects as go
12
12
import torch
13
13
from plotly .subplots import make_subplots
14
- from pytorch3d .renderer import TexturesVertex
14
+ from pytorch3d .renderer import TexturesVertex , RayBundle , ray_bundle_to_ray_points
15
15
from pytorch3d .renderer .camera_utils import camera_to_eye_at_up
16
16
from pytorch3d .renderer .cameras import CamerasBase
17
17
from pytorch3d .structures import Meshes , Pointclouds , join_meshes_as_scene
18
18
19
19
20
+ Struct = Union [CamerasBase , Meshes , Pointclouds , RayBundle ]
21
+
22
+
23
+ def _get_struct_len (struct : Struct ): # pragma: no cover
24
+ """
25
+ Returns the length (usually corresponds to the batch size) of the input structure.
26
+ """
27
+ return len (struct .directions ) if isinstance (struct , RayBundle ) else len (struct )
28
+
29
+
20
30
def get_camera_wireframe (scale : float = 0.3 ): # pragma: no cover
21
31
"""
22
32
Returns a wireframe of a 3D line-plot of a camera symbol.
@@ -55,18 +65,22 @@ class Lighting(NamedTuple): # pragma: no cover
55
65
56
66
57
67
def plot_scene (
58
- plots : Dict [str , Dict [str , Union [ Pointclouds , Meshes , CamerasBase ] ]],
68
+ plots : Dict [str , Dict [str , Struct ]],
59
69
* ,
60
70
viewpoint_cameras : Optional [CamerasBase ] = None ,
61
71
ncols : int = 1 ,
62
72
camera_scale : float = 0.3 ,
63
73
pointcloud_max_points : int = 20000 ,
64
74
pointcloud_marker_size : int = 1 ,
75
+ raybundle_max_rays : int = 20000 ,
76
+ raybundle_max_points_per_ray : int = 1000 ,
77
+ raybundle_ray_point_marker_size : int = 1 ,
78
+ raybundle_ray_line_width : int = 1 ,
65
79
** kwargs ,
66
80
): # pragma: no cover
67
81
"""
68
- Main function to visualize Meshes, Cameras and Pointclouds .
69
- Plots input Pointclouds , Meshes, and Cameras data into named subplots,
82
+ Main function to visualize Cameras, Meshes, Pointclouds, and RayBundle .
83
+ Plots input Cameras , Meshes, Pointclouds, and RayBundle data into named subplots,
70
84
with named traces based on the dictionary keys. Cameras are
71
85
rendered at the camera center location using a wireframe.
72
86
@@ -87,6 +101,13 @@ def plot_scene(
87
101
pointcloud_max_points is used.
88
102
pointcloud_marker_size: the size of the points rendered by plotly
89
103
when plotting a pointcloud.
104
+ raybundle_max_rays: maximum number of rays of a RayBundle to visualize. Randomly
105
+ subsamples without replacement in case the number of rays is bigger than max_rays.
106
+ raybundle_max_points_per_ray: the maximum number of points per ray in RayBundle
107
+ to visualize. If more are present, a random sample of size
108
+ max_points_per_ray is used.
109
+ raybundle_ray_point_marker_size: the size of the ray points of a plotted RayBundle
110
+ raybundle_ray_line_width: the width of the plotted rays of a RayBundle
90
111
**kwargs: Accepts lighting (a Lighting object) and any of the args xaxis,
91
112
yaxis and zaxis which Plotly's scene accepts. Accepts axis_args,
92
113
which is an AxisArgs object that is applied to all 3 axes.
@@ -186,6 +207,18 @@ def plot_scene(
186
207
The above example will render one subplot with the mesh object
187
208
and two cameras.
188
209
210
+ RayBundle visualization is also supproted:
211
+ ..code-block::python
212
+ cameras = PerspectiveCameras(...)
213
+ ray_bundle = RayBundle(origins=..., lengths=..., directions=..., xys=...)
214
+ fig = plot_scene({
215
+ "subplot1_title": {
216
+ "ray_bundle_trace_title": ray_bundle,
217
+ "cameras_trace_title": cameras,
218
+ },
219
+ })
220
+ fig.show()
221
+
189
222
For an example of using kwargs, see below:
190
223
..code-block::python
191
224
mesh = ...
@@ -264,11 +297,22 @@ def plot_scene(
264
297
_add_camera_trace (
265
298
fig , struct , trace_name , subplot_idx , ncols , camera_scale
266
299
)
300
+ elif isinstance (struct , RayBundle ):
301
+ _add_ray_bundle_trace (
302
+ fig ,
303
+ struct ,
304
+ trace_name ,
305
+ subplot_idx ,
306
+ ncols ,
307
+ raybundle_max_rays ,
308
+ raybundle_max_points_per_ray ,
309
+ raybundle_ray_point_marker_size ,
310
+ raybundle_ray_line_width ,
311
+ )
267
312
else :
268
313
raise ValueError (
269
- "struct {} is not a Cameras, Meshes or Pointclouds object" .format (
270
- struct
271
- )
314
+ "struct {} is not a Cameras, Meshes, Pointclouds," .format (struct )
315
+ + " or RayBundle object."
272
316
)
273
317
274
318
# Ensure update for every subplot.
@@ -329,7 +373,8 @@ def plot_scene(
329
373
330
374
def plot_batch_individually (
331
375
batched_structs : Union [
332
- List [Union [Meshes , Pointclouds , CamerasBase ]], Meshes , Pointclouds , CamerasBase
376
+ List [Struct ],
377
+ Struct ,
333
378
],
334
379
* ,
335
380
viewpoint_cameras : Optional [CamerasBase ] = None ,
@@ -340,26 +385,27 @@ def plot_batch_individually(
340
385
): # pragma: no cover
341
386
"""
342
387
This is a higher level plotting function than plot_scene, for plotting
343
- Cameras, Meshes and Pointclouds in simple cases. The simplest use is to plot a
344
- single Cameras, Meshes or Pointclouds object, where you just pass it in as a
345
- one element list. This will plot each batch element in a separate subplot.
388
+ Cameras, Meshes, Pointclouds, and RayBundle in simple cases. The simplest use
389
+ is to plot a single Cameras, Meshes, Pointclouds, or a RayBundle object,
390
+ where you just pass it in as a one element list. This will plot each batch
391
+ element in a separate subplot.
346
392
347
- More generally, you can supply multiple Cameras, Meshes or Pointclouds
393
+ More generally, you can supply multiple Cameras, Meshes, Pointclouds, or RayBundle
348
394
having the same batch size `n`. In this case, there will be `n` subplots,
349
395
each depicting the corresponding batch element of all the inputs.
350
396
351
- In addition, you can include Cameras, Meshes and Pointclouds of size 1 in
397
+ In addition, you can include Cameras, Meshes, Pointclouds, or RayBundle of size 1 in
352
398
the input. These will either be rendered in the first subplot
353
399
(if extend_struct is False), or in every subplot.
354
400
355
401
Args:
356
- batched_structs: a list of Cameras, Meshes and/or Pointclouds to be rendered.
357
- Each structure's corresponding batch element will be plotted in
358
- a single subplot, resulting in n subplots for a batch of size n.
402
+ batched_structs: a list of Cameras, Meshes, Pointclouds, and RayBundle
403
+ to be rendered. Each structure's corresponding batch element will be
404
+ plotted in a single subplot, resulting in n subplots for a batch of size n.
359
405
Every struct should either have the same batch size or be of batch size 1.
360
406
See extend_struct and the description above for how batch size 1 structs
361
- are handled. Also accepts a single Cameras, Meshes or Pointclouds object,
362
- which will have each individual element plotted in its own subplot.
407
+ are handled. Also accepts a single Cameras, Meshes, Pointclouds, and RayBundle
408
+ object, which will have each individual element plotted in its own subplot.
363
409
viewpoint_cameras: an instance of a Cameras object providing a location
364
410
to view the plotly plot from. If the batch size is equal
365
411
to the number of subplots, it is a one to one mapping.
@@ -407,13 +453,14 @@ def plot_batch_individually(
407
453
return
408
454
max_size = 0
409
455
if isinstance (batched_structs , list ):
410
- max_size = max (len (s ) for s in batched_structs )
456
+ max_size = max (_get_struct_len (s ) for s in batched_structs )
411
457
for struct in batched_structs :
412
- if len (struct ) not in (1 , max_size ):
413
- msg = "invalid batch size {} provided: {}" .format (len (struct ), struct )
458
+ struct_len = _get_struct_len (struct )
459
+ if struct_len not in (1 , max_size ):
460
+ msg = "invalid batch size {} provided: {}" .format (struct_len , struct )
414
461
raise ValueError (msg )
415
462
else :
416
- max_size = len (batched_structs )
463
+ max_size = _get_struct_len (batched_structs )
417
464
418
465
if max_size == 0 :
419
466
msg = "No data is provided with at least one element"
@@ -437,7 +484,8 @@ def plot_batch_individually(
437
484
if isinstance (batched_structs , list ):
438
485
for i , batched_struct in enumerate (batched_structs ):
439
486
# check for whether this struct needs to be extended
440
- if i >= len (batched_struct ) and not extend_struct :
487
+ batched_struct_len = _get_struct_len (batched_struct )
488
+ if i >= batched_struct_len and not extend_struct :
441
489
continue
442
490
_add_struct_from_batch (
443
491
batched_struct , scene_num , subplot_title , scene_dictionary , i + 1
@@ -453,10 +501,10 @@ def plot_batch_individually(
453
501
454
502
455
503
def _add_struct_from_batch (
456
- batched_struct : Union [ CamerasBase , Meshes , Pointclouds ] ,
504
+ batched_struct : Struct ,
457
505
scene_num : int ,
458
506
subplot_title : str ,
459
- scene_dictionary : Dict [str , Dict [str , Union [ CamerasBase , Meshes , Pointclouds ] ]],
507
+ scene_dictionary : Dict [str , Dict [str , Struct ]],
460
508
trace_idx : int = 1 ,
461
509
): # pragma: no cover
462
510
"""
@@ -492,6 +540,15 @@ def _add_struct_from_batch(
492
540
# torch.Tensor, torch.nn.Module]` is not a function.
493
541
T = T [t_idx ].unsqueeze (0 )
494
542
struct = CamerasBase (device = batched_struct .device , R = R , T = T )
543
+ elif isinstance (batched_struct , RayBundle ):
544
+ # for RayBundle we treat the 1st dim as the batch index
545
+ struct_idx = min (scene_num , len (batched_struct .lengths ) - 1 )
546
+ struct = RayBundle (
547
+ ** {
548
+ attr : getattr (batched_struct , attr )[struct_idx ]
549
+ for attr in ["origins" , "directions" , "lengths" , "xys" ]
550
+ }
551
+ )
495
552
else : # batched meshes and pointclouds are indexable
496
553
struct_idx = min (scene_num , len (batched_struct ) - 1 )
497
554
struct = batched_struct [struct_idx ]
@@ -702,6 +759,138 @@ def _add_camera_trace(
702
759
_update_axes_bounds (verts_center , max_expand , current_layout )
703
760
704
761
762
+ def _add_ray_bundle_trace (
763
+ fig : go .Figure ,
764
+ ray_bundle : RayBundle ,
765
+ trace_name : str ,
766
+ subplot_idx : int ,
767
+ ncols : int ,
768
+ max_rays : int ,
769
+ max_points_per_ray : int ,
770
+ marker_size : int ,
771
+ line_width : int ,
772
+ ): # pragma: no cover
773
+ """
774
+ Adds a trace rendering a RayBundle object to the passed in figure, with
775
+ a given name and in a specific subplot.
776
+
777
+ Args:
778
+ fig: plotly figure to add the trace within.
779
+ cameras: the Cameras object to render. It can be batched.
780
+ trace_name: name to label the trace with.
781
+ subplot_idx: identifies the subplot, with 0 being the top left.
782
+ ncols: the number of subplots per row.
783
+ max_rays: maximum number of plotted rays in total. Randomly subsamples
784
+ without replacement in case the number of rays is bigger than max_rays.
785
+ max_points_per_ray: maximum number of points plotted per ray.
786
+ marker_size: the size of the ray point markers.
787
+ line_width: the width of the ray lines.
788
+ """
789
+
790
+ n_pts_per_ray = ray_bundle .lengths .shape [- 1 ]
791
+ n_rays = ray_bundle .lengths .shape [:- 1 ].numel () # pyre-ignore[16]
792
+
793
+ # flatten all batches of rays into a single big bundle
794
+ ray_bundle_flat = RayBundle (
795
+ ** {
796
+ attr : torch .flatten (getattr (ray_bundle , attr ), start_dim = 0 , end_dim = - 2 )
797
+ for attr in ["origins" , "directions" , "lengths" , "xys" ]
798
+ }
799
+ )
800
+
801
+ # subsample the rays (if needed)
802
+ if n_rays > max_rays :
803
+ indices_rays = torch .randperm (n_rays )[:max_rays ]
804
+ ray_bundle_flat = RayBundle (
805
+ ** {
806
+ attr : getattr (ray_bundle_flat , attr )[indices_rays ]
807
+ for attr in ["origins" , "directions" , "lengths" , "xys" ]
808
+ }
809
+ )
810
+
811
+ # make ray line endpoints
812
+ min_max_ray_depth = torch .stack (
813
+ [
814
+ ray_bundle_flat .lengths .min (dim = 1 ).values , # pyre-ignore[16]
815
+ ray_bundle_flat .lengths .max (dim = 1 ).values ,
816
+ ],
817
+ dim = - 1 ,
818
+ )
819
+ ray_lines_endpoints = ray_bundle_to_ray_points (
820
+ ray_bundle_flat ._replace (lengths = min_max_ray_depth )
821
+ )
822
+
823
+ # make the ray lines for plotly plotting
824
+ nan_tensor = torch .Tensor ([[float ("NaN" )] * 3 ])
825
+ ray_lines = torch .empty (size = (1 , 3 ))
826
+ for ray_line in ray_lines_endpoints :
827
+ # We combine the ray lines into a single tensor to plot them in a
828
+ # single trace. The NaNs are inserted between sets of ray lines
829
+ # so that the lines drawn by Plotly are not drawn between
830
+ # lines that belong to different rays.
831
+ ray_lines = torch .cat ((ray_lines , nan_tensor , ray_line ))
832
+ x , y , z = ray_lines .detach ().cpu ().numpy ().T .astype (float )
833
+ row , col = subplot_idx // ncols + 1 , subplot_idx % ncols + 1
834
+ fig .add_trace (
835
+ go .Scatter3d (
836
+ x = x ,
837
+ y = y ,
838
+ z = z ,
839
+ marker = {"size" : 0.1 },
840
+ line = {"width" : line_width },
841
+ name = trace_name ,
842
+ ),
843
+ row = row ,
844
+ col = col ,
845
+ )
846
+
847
+ # subsample the ray points (if needed)
848
+ if n_pts_per_ray > max_points_per_ray :
849
+ indices_ray_pts = torch .cat (
850
+ [
851
+ torch .randperm (n_pts_per_ray )[:max_points_per_ray ] + ri * n_pts_per_ray
852
+ for ri in range (ray_bundle_flat .lengths .shape [0 ])
853
+ ]
854
+ )
855
+ ray_bundle_flat = ray_bundle_flat ._replace (
856
+ lengths = ray_bundle_flat .lengths .reshape (- 1 )[indices_ray_pts ].reshape (
857
+ ray_bundle_flat .lengths .shape [0 ], - 1
858
+ )
859
+ )
860
+
861
+ # plot the ray points
862
+ ray_points = (
863
+ ray_bundle_to_ray_points (ray_bundle_flat )
864
+ .view (- 1 , 3 )
865
+ .detach ()
866
+ .cpu ()
867
+ .numpy ()
868
+ .astype (float )
869
+ )
870
+ fig .add_trace (
871
+ go .Scatter3d (
872
+ x = ray_points [:, 0 ],
873
+ y = ray_points [:, 1 ],
874
+ z = ray_points [:, 2 ],
875
+ mode = "markers" ,
876
+ name = trace_name + "_points" ,
877
+ marker = {"size" : marker_size },
878
+ ),
879
+ row = row ,
880
+ col = col ,
881
+ )
882
+
883
+ # Access the current subplot's scene configuration
884
+ plot_scene = "scene" + str (subplot_idx + 1 )
885
+ current_layout = fig ["layout" ][plot_scene ]
886
+
887
+ # update the bounds of the axes for the current trace
888
+ all_ray_points = ray_bundle_to_ray_points (ray_bundle ).view (- 1 , 3 )
889
+ ray_points_center = all_ray_points .mean (dim = 0 )
890
+ max_expand = (all_ray_points .max (0 )[0 ] - all_ray_points .min (0 )[0 ]).max ().item ()
891
+ _update_axes_bounds (ray_points_center , float (max_expand ), current_layout )
892
+
893
+
705
894
def _gen_fig_with_subplots (
706
895
batch_size : int , ncols : int , subplot_titles : List [str ]
707
896
): # pragma: no cover
0 commit comments