@@ -1997,6 +1997,97 @@ def post_process_instance_segmentation(
1997
1997
results .append ({"segmentation" : segmentation , "segments_info" : segments })
1998
1998
return results
1999
1999
2000
+ def post_process_instance_segmentation_v2 (
2001
+ self ,
2002
+ outputs ,
2003
+ threshold : float = 0.5 ,
2004
+ mask_threshold : float = 0.5 ,
2005
+ overlap_mask_area_threshold : float = 0.8 ,
2006
+ target_sizes : Optional [List [Tuple [int , int ]]] = None ,
2007
+ ) -> List [Dict ]:
2008
+ """
2009
+ Converts the output of [`DetrForSegmentation`] into instance segmentation predictions. Only supports PyTorch.
2010
+
2011
+ Args:
2012
+ outputs ([`DetrForSegmentation`]):
2013
+ Raw outputs of the model.
2014
+ threshold (`float`, *optional*, defaults to 0.5):
2015
+ The probability score threshold to keep predicted instance masks.
2016
+ mask_threshold (`float`, *optional*, defaults to 0.5):
2017
+ Threshold to use when turning the predicted masks into binary values.
2018
+ overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
2019
+ The overlap mask area threshold to merge or discard small disconnected parts within each binary
2020
+ instance mask.
2021
+ target_sizes (`List[Tuple]`, *optional*):
2022
+ List of length (batch_size), where each list item (`Tuple[int, int]]`) corresponds to the requested
2023
+ final size (height, width) of each prediction. If unset, predictions will not be resized.
2024
+ return_coco_annotation (`bool`, *optional*):
2025
+ Defaults to `False`. If set to `True`, segmentation maps are returned in COCO run-length encoding (RLE)
2026
+ format.
2027
+ return_binary_maps (`bool`, *optional*, defaults to `False`):
2028
+ If set to `True`, segmentation maps are returned as a concatenated tensor of binary segmentation maps
2029
+ (one per detected instance).
2030
+ Returns:
2031
+ `List[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys:
2032
+ - **segmentation** -- A tensor of shape `(height, width)` where each pixel represents a `segment_id` or
2033
+ `List[List]` run-length encoding (RLE) of the segmentation map if return_coco_annotation is set to
2034
+ `True`. Set to `None` if no mask if found above `threshold`.
2035
+ - **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`.
2036
+ - **score** -- Prediction score of segment with `segment_id`.
2037
+ """
2038
+
2039
+ # [batch_size, num_queries, num_classes+1]
2040
+ class_queries_logits = outputs .logits
2041
+ # [batch_size, num_queries, height, width]
2042
+ masks_queries_logits = outputs .pred_masks
2043
+
2044
+ device = masks_queries_logits .device
2045
+ num_classes = class_queries_logits .shape [- 1 ] - 1
2046
+ num_queries = class_queries_logits .shape [- 2 ]
2047
+
2048
+ # Loop over items in batch size
2049
+ results : List [Dict [str , TensorType ]] = []
2050
+
2051
+ for i in range (class_queries_logits .shape [0 ]):
2052
+ mask_pred = masks_queries_logits [i ]
2053
+ mask_cls = class_queries_logits [i ]
2054
+
2055
+ scores = torch .nn .functional .softmax (mask_cls , dim = - 1 )[:, :- 1 ]
2056
+ labels = torch .arange (num_classes , device = device ).unsqueeze (0 ).repeat (num_queries , 1 ).flatten (0 , 1 )
2057
+
2058
+ scores_per_image , topk_indices = scores .flatten (0 , 1 ).topk (num_queries , sorted = False )
2059
+ labels_per_image = labels [topk_indices ]
2060
+
2061
+ topk_indices = torch .div (topk_indices , num_classes , rounding_mode = "floor" )
2062
+ mask_pred = mask_pred [topk_indices ]
2063
+ pred_masks = (mask_pred > 0 ).float ()
2064
+
2065
+ # Calculate average mask prob
2066
+ mask_scores_per_image = (mask_pred .sigmoid ().flatten (1 ) * pred_masks .flatten (1 )).sum (1 ) / (
2067
+ pred_masks .flatten (1 ).sum (1 ) + 1e-6
2068
+ )
2069
+ pred_scores = scores_per_image * mask_scores_per_image
2070
+ pred_classes = labels_per_image
2071
+
2072
+ mask_pred , pred_scores , pred_classes = remove_low_and_no_objects (
2073
+ mask_pred , pred_scores , pred_classes , threshold , num_classes
2074
+ )
2075
+
2076
+ segmentation = torch .zeros ((384 , 384 )) - 1
2077
+ if target_sizes is not None :
2078
+ size = target_sizes [i ] if isinstance (target_sizes [i ], tuple ) else target_sizes [i ].cpu ().tolist ()
2079
+ segmentation = torch .zeros (size ) - 1
2080
+ pred_masks = torch .nn .functional .interpolate (pred_masks .unsqueeze (0 ), size = size , mode = "nearest" )[0 ]
2081
+
2082
+ keep = pred_scores > threshold
2083
+
2084
+ score = pred_scores [keep ]
2085
+ label = pred_classes [keep ]
2086
+ segmentation = pred_masks [keep ] > mask_threshold
2087
+
2088
+ results .append ({"score" : score , "label" : label , "segmentation" : segmentation })
2089
+ return results
2090
+
2000
2091
# inspired by https://github.com/facebookresearch/detr/blob/master/models/segmentation.py#L241
2001
2092
def post_process_panoptic_segmentation (
2002
2093
self ,
0 commit comments