1
- from typing import List , Optional , Tuple , Union
1
+ # Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.
2
+ # Copyright 2024-2025 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # --------------------------------------------------------------------------
16
+ # More information and citation instructions are available on the
17
+ # Marigold project website: https://marigoldcomputervision.github.io
18
+ # --------------------------------------------------------------------------
19
+ from typing import Any , Dict , List , Optional , Tuple , Union
2
20
3
21
import numpy as np
4
22
import PIL
@@ -379,7 +397,7 @@ def visualize_depth(
379
397
val_min : float = 0.0 ,
380
398
val_max : float = 1.0 ,
381
399
color_map : str = "Spectral" ,
382
- ) -> Union [ PIL . Image . Image , List [PIL .Image .Image ] ]:
400
+ ) -> List [PIL .Image .Image ]:
383
401
"""
384
402
Visualizes depth maps, such as predictions of the `MarigoldDepthPipeline`.
385
403
@@ -391,7 +409,7 @@ def visualize_depth(
391
409
color_map (`str`, *optional*, defaults to `"Spectral"`): Color map used to convert a single-channel
392
410
depth prediction into colored representation.
393
411
394
- Returns: `PIL.Image.Image` or ` List[PIL.Image.Image]` with depth maps visualization.
412
+ Returns: `List[PIL.Image.Image]` with depth maps visualization.
395
413
"""
396
414
if val_max <= val_min :
397
415
raise ValueError (f"Invalid values range: [{ val_min } , { val_max } ]." )
@@ -436,7 +454,7 @@ def export_depth_to_16bit_png(
436
454
depth : Union [np .ndarray , torch .Tensor , List [np .ndarray ], List [torch .Tensor ]],
437
455
val_min : float = 0.0 ,
438
456
val_max : float = 1.0 ,
439
- ) -> Union [ PIL . Image . Image , List [PIL .Image .Image ] ]:
457
+ ) -> List [PIL .Image .Image ]:
440
458
def export_depth_to_16bit_png_one (img , idx = None ):
441
459
prefix = "Depth" + (f"[{ idx } ]" if idx else "" )
442
460
if not isinstance (img , np .ndarray ) and not torch .is_tensor (img ):
@@ -478,7 +496,7 @@ def visualize_normals(
478
496
flip_x : bool = False ,
479
497
flip_y : bool = False ,
480
498
flip_z : bool = False ,
481
- ) -> Union [ PIL . Image . Image , List [PIL .Image .Image ] ]:
499
+ ) -> List [PIL .Image .Image ]:
482
500
"""
483
501
Visualizes surface normals, such as predictions of the `MarigoldNormalsPipeline`.
484
502
@@ -492,7 +510,7 @@ def visualize_normals(
492
510
flip_z (`bool`, *optional*, defaults to `False`): Flips the Z axis of the normals frame of reference.
493
511
Default direction is facing the observer.
494
512
495
- Returns: `PIL.Image.Image` or ` List[PIL.Image.Image]` with surface normals visualization.
513
+ Returns: `List[PIL.Image.Image]` with surface normals visualization.
496
514
"""
497
515
flip_vec = None
498
516
if any ((flip_x , flip_y , flip_z )):
@@ -528,6 +546,99 @@ def visualize_normals_one(img, idx=None):
528
546
else :
529
547
raise ValueError (f"Unexpected input type: { type (normals )} " )
530
548
549
@staticmethod
def visualize_intrinsics(
    prediction: Union[
        np.ndarray,
        torch.Tensor,
        List[np.ndarray],
        List[torch.Tensor],
    ],
    target_properties: Dict[str, Any],
    color_map: Union[str, Dict[str, str]] = "binary",
) -> List[Dict[str, PIL.Image.Image]]:
    """
    Visualizes intrinsic image decomposition, such as predictions of the `MarigoldIntrinsicsPipeline`.

    Args:
        prediction (`Union[np.ndarray, torch.Tensor, List[np.ndarray], List[torch.Tensor]]`):
            Intrinsic image decomposition. A single array or tensor must have the shape `[N*T,3,H,W]` after
            conversion, where `T` is the number of entries in `target_names`; it is regrouped into `N` samples
            of `T` targets each. When a list is passed, each element is iterated as the per-target images of
            one sample.
        target_properties (`Dict[str, Any]`):
            Decomposition properties. Expected entries: `target_names: List[str]` and, for each target and
            sub-target name, a dictionary with keys `prediction_space: str`,
            `sub_target_names: List[Union[str, Null]]` (must have 3 entries, null for missing modalities),
            `up_to_scale: bool`.
        color_map (`Union[str, Dict[str, str]]`, *optional*, defaults to `"binary"`):
            Color map used to convert single-channel predictions into colored representations. When a
            dictionary is passed, each modality can be colored with its own color map; modalities missing from
            the dictionary fall back to `"binary"`.

    Returns: `List[Dict[str, PIL.Image.Image]]` with intrinsic image decomposition visualization, one dictionary
        per sample, keyed by target (and sub-target) name.

    Raises:
        ValueError: If `target_names` is missing or empty, `color_map` has an unsupported type, the prediction
            is `None`, has an unexpected type or shape, or the sub-target names of a stacked target are invalid.
    """
    if "target_names" not in target_properties:
        raise ValueError("Missing `target_names` in target_properties")
    if not isinstance(color_map, str) and not (
        isinstance(color_map, dict)
        and all(isinstance(k, str) and isinstance(v, str) for k, v in color_map.items())
    ):
        raise ValueError("`color_map` must be a string or a dictionary of strings")
    target_names = target_properties["target_names"]
    n_targets = len(target_names)
    if n_targets == 0:
        # Without this guard the shape check below would fail with a ZeroDivisionError.
        raise ValueError("`target_names` in target_properties must not be empty")

    def _linear_to_display(img, properties):
        # Optionally normalize by the maximum (floored at 1e-6 to avoid division by zero) when the
        # modality is flagged as `up_to_scale`, then apply the 1/2.2 gamma.
        if properties.get("up_to_scale", False):
            img = img / max(img.max().item(), 1e-6)
        return img ** (1 / 2.2)

    def _visualize_targets_one(images):
        # images: [T, 3, H, W]
        out = {}
        for target_name, img in zip(target_names, images):
            img = img.permute(1, 2, 0)  # [H, W, 3]
            properties = target_properties[target_name]
            prediction_space = properties.get("prediction_space", "srgb")
            if prediction_space == "stack":
                # Each of the three channels holds a separate single-channel modality (or nothing).
                sub_target_names = properties["sub_target_names"]
                if len(sub_target_names) != 3 or any(
                    not (isinstance(s, str) or s is None) for s in sub_target_names
                ):
                    raise ValueError(f"Unexpected target sub-names {sub_target_names} in {target_name}")
                for i, sub_target_name in enumerate(sub_target_names):
                    if sub_target_name is None:
                        continue
                    sub_img = img[:, :, i]
                    sub_properties = target_properties[sub_target_name]
                    if sub_properties.get("prediction_space", "srgb") == "linear":
                        sub_img = _linear_to_display(sub_img, sub_properties)
                    cmap_name = (
                        color_map if isinstance(color_map, str) else color_map.get(sub_target_name, "binary")
                    )
                    sub_img = MarigoldImageProcessor.colormap(sub_img, cmap=cmap_name, bytes=True)
                    out[sub_target_name] = PIL.Image.fromarray(sub_img.cpu().numpy())
            elif prediction_space == "linear":
                img = _linear_to_display(img, properties)
            # Any other prediction space (including "srgb") needs no conversion. The three-channel image is
            # exported for every target, stacked ones included.
            # NOTE(review): values outside [0, 1] are not clamped before the uint8 cast.
            img = (img * 255).to(dtype=torch.uint8, device="cpu").numpy()
            out[target_name] = PIL.Image.fromarray(img)
        return out

    if prediction is None or isinstance(prediction, list) and any(o is None for o in prediction):
        raise ValueError("Input prediction is `None`")
    if isinstance(prediction, (np.ndarray, torch.Tensor)):
        prediction = MarigoldImageProcessor.expand_tensor_or_array(prediction)
        if isinstance(prediction, np.ndarray):
            prediction = MarigoldImageProcessor.numpy_to_pt(prediction)  # [N*T,3,H,W]
        if not (prediction.ndim == 4 and prediction.shape[1] == 3 and prediction.shape[0] % n_targets == 0):
            raise ValueError(f"Unexpected input shape={prediction.shape}, expecting [N*T,3,H,W].")
        n_total, _, height, width = prediction.shape
        # Regroup the flat batch into per-sample stacks of targets: [N, T, 3, H, W].
        prediction = prediction.reshape(n_total // n_targets, n_targets, 3, height, width)
    elif not isinstance(prediction, list):
        raise ValueError(f"Unexpected input type: {type(prediction)}")
    return [_visualize_targets_one(images) for images in prediction]
641
+
531
642
@staticmethod
532
643
def visualize_uncertainty (
533
644
uncertainty : Union [
@@ -537,24 +648,26 @@ def visualize_uncertainty(
537
648
List [torch .Tensor ],
538
649
],
539
650
saturation_percentile = 95 ,
540
- ) -> Union [ PIL . Image . Image , List [PIL .Image .Image ] ]:
651
+ ) -> List [PIL .Image .Image ]:
541
652
"""
542
- Visualizes dense uncertainties, such as produced by `MarigoldDepthPipeline` or `MarigoldNormalsPipeline`.
653
+ Visualizes dense uncertainties, such as produced by `MarigoldDepthPipeline`, `MarigoldNormalsPipeline`, or
654
+ `MarigoldIntrinsicsPipeline`.
543
655
544
656
Args:
545
657
uncertainty (`Union[np.ndarray, torch.Tensor, List[np.ndarray], List[torch.Tensor]]`):
546
658
Uncertainty maps.
547
659
saturation_percentile (`int`, *optional*, defaults to `95`):
548
660
Specifies the percentile uncertainty value visualized with maximum intensity.
549
661
550
- Returns: `PIL.Image.Image` or ` List[PIL.Image.Image]` with uncertainty visualization.
662
+ Returns: `List[PIL.Image.Image]` with uncertainty visualization.
551
663
"""
552
664
553
665
def visualize_uncertainty_one (img , idx = None ):
554
666
prefix = "Uncertainty" + (f"[{ idx } ]" if idx else "" )
555
667
if img .min () < 0 :
556
- raise ValueError (f"{ prefix } : unexected data range, min={ img .min ()} ." )
557
- img = img .squeeze (0 ).cpu ().numpy ()
668
+ raise ValueError (f"{ prefix } : unexpected data range, min={ img .min ()} ." )
669
+ img = img .permute (1 , 2 , 0 ) # [H,W,C]
670
+ img = img .squeeze (2 ).cpu ().numpy () # [H,W] or [H,W,3]
558
671
saturation_value = np .percentile (img , saturation_percentile )
559
672
img = np .clip (img * 255 / saturation_value , 0 , 255 )
560
673
img = img .astype (np .uint8 )
@@ -566,9 +679,9 @@ def visualize_uncertainty_one(img, idx=None):
566
679
if isinstance (uncertainty , (np .ndarray , torch .Tensor )):
567
680
uncertainty = MarigoldImageProcessor .expand_tensor_or_array (uncertainty )
568
681
if isinstance (uncertainty , np .ndarray ):
569
- uncertainty = MarigoldImageProcessor .numpy_to_pt (uncertainty ) # [N,1 ,H,W]
570
- if not (uncertainty .ndim == 4 and uncertainty .shape [1 ] == 1 ):
571
- raise ValueError (f"Unexpected input shape={ uncertainty .shape } , expecting [N,1 ,H,W]." )
682
+ uncertainty = MarigoldImageProcessor .numpy_to_pt (uncertainty ) # [N,C ,H,W]
683
+ if not (uncertainty .ndim == 4 and uncertainty .shape [1 ] in ( 1 , 3 ) ):
684
+ raise ValueError (f"Unexpected input shape={ uncertainty .shape } , expecting [N,C ,H,W] with C in (1,3) ." )
572
685
return [visualize_uncertainty_one (img , idx ) for idx , img in enumerate (uncertainty )]
573
686
elif isinstance (uncertainty , list ):
574
687
return [visualize_uncertainty_one (img , idx ) for idx , img in enumerate (uncertainty )]
0 commit comments