
Commit e4984d9

Merge pull request huggingface#16 from Superb-AI-Suite/feat/triton_format_processing
make style & add postprocessing for instance segmentation compatible with Triton
2 parents b015820 + c50a6f7 commit e4984d9

File tree: 3 files changed, +177 −1 lines
src/transformers/models/detr/image_processing_detr.py

Lines changed: 91 additions & 0 deletions
@@ -1997,6 +1997,97 @@ def post_process_instance_segmentation(
         results.append({"segmentation": segmentation, "segments_info": segments})
         return results
 
+    def post_process_instance_segmentation_v2(
+        self,
+        outputs,
+        threshold: float = 0.5,
+        mask_threshold: float = 0.5,
+        overlap_mask_area_threshold: float = 0.8,
+        target_sizes: Optional[List[Tuple[int, int]]] = None,
+    ) -> List[Dict]:
+        """
+        Converts the output of [`DetrForSegmentation`] into instance segmentation predictions. Only supports
+        PyTorch.
+
+        Args:
+            outputs ([`DetrForSegmentation`]):
+                Raw outputs of the model.
+            threshold (`float`, *optional*, defaults to 0.5):
+                The probability score threshold used to keep predicted instance masks.
+            mask_threshold (`float`, *optional*, defaults to 0.5):
+                Threshold used when turning the predicted masks into binary values.
+            overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
+                The overlap mask area threshold used to merge or discard small disconnected parts within each
+                binary instance mask.
+            target_sizes (`List[Tuple[int, int]]`, *optional*):
+                List of length `batch_size`, where each item is the requested final size `(height, width)` of the
+                corresponding prediction. If unset, predictions are not resized.
+        Returns:
+            `List[Dict]`: A list of dictionaries, one per image, each containing three keys:
+            - **score** -- Prediction scores of the kept instances.
+            - **label** -- Integer label / semantic class ids of the kept instances.
+            - **segmentation** -- A boolean tensor of shape `(num_instances, height, width)` with one binary mask
+              per kept instance. Empty if no mask is found above `threshold`.
+        """
+        # [batch_size, num_queries, num_classes+1]
+        class_queries_logits = outputs.logits
+        # [batch_size, num_queries, height, width]
+        masks_queries_logits = outputs.pred_masks
+
+        device = masks_queries_logits.device
+        num_classes = class_queries_logits.shape[-1] - 1
+        num_queries = class_queries_logits.shape[-2]
+
+        # Loop over items in the batch
+        results: List[Dict[str, TensorType]] = []
+
+        for i in range(class_queries_logits.shape[0]):
+            mask_pred = masks_queries_logits[i]
+            mask_cls = class_queries_logits[i]
+
+            # Class probabilities per query, dropping the trailing "no object" class
+            scores = torch.nn.functional.softmax(mask_cls, dim=-1)[:, :-1]
+            labels = torch.arange(num_classes, device=device).unsqueeze(0).repeat(num_queries, 1).flatten(0, 1)
+
+            scores_per_image, topk_indices = scores.flatten(0, 1).topk(num_queries, sorted=False)
+            labels_per_image = labels[topk_indices]
+
+            # Map flattened (query, class) indices back to query indices
+            topk_indices = torch.div(topk_indices, num_classes, rounding_mode="floor")
+            mask_pred = mask_pred[topk_indices]
+            pred_masks = (mask_pred > 0).float()
+
+            # Average mask probability over the predicted foreground region
+            mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * pred_masks.flatten(1)).sum(1) / (
+                pred_masks.flatten(1).sum(1) + 1e-6
+            )
+            pred_scores = scores_per_image * mask_scores_per_image
+            pred_classes = labels_per_image
+
+            mask_pred, pred_scores, pred_classes = remove_low_and_no_objects(
+                mask_pred, pred_scores, pred_classes, threshold, num_classes
+            )
+            # Re-binarize after filtering so the masks stay aligned with the filtered scores and classes
+            pred_masks = (mask_pred > 0).float()
+
+            if target_sizes is not None:
+                size = target_sizes[i] if isinstance(target_sizes[i], tuple) else target_sizes[i].cpu().tolist()
+                pred_masks = torch.nn.functional.interpolate(pred_masks.unsqueeze(0), size=size, mode="nearest")[0]
+
+            keep = pred_scores > threshold
+
+            score = pred_scores[keep]
+            label = pred_classes[keep]
+            segmentation = pred_masks[keep] > mask_threshold
+
+            results.append({"score": score, "label": label, "segmentation": segmentation})
+        return results
+
     # inspired by https://github.com/facebookresearch/detr/blob/master/models/segmentation.py#L241
     def post_process_panoptic_segmentation(
         self,
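For reference, the new method can be exercised end to end. The sketch below is illustrative, not part of the commit: it assumes the public facebook/detr-resnet-50-panoptic checkpoint, a stand-in image path, and that the method is exposed on DetrImageProcessor as in the diff above.

import torch
from PIL import Image
from transformers import DetrForSegmentation, DetrImageProcessor

processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50-panoptic")
model = DetrForSegmentation.from_pretrained("facebook/detr-resnet-50-panoptic")

image = Image.open("example.jpg")  # stand-in input image
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# One dict per image: "score" and "label" are 1-D tensors over the kept
# instances; "segmentation" stacks one binary mask per instance.
result = processor.post_process_instance_segmentation_v2(
    outputs, threshold=0.5, target_sizes=[image.size[::-1]]  # PIL size is (w, h); reverse to (h, w)
)[0]
print(result["segmentation"].shape)  # (num_instances, height, width)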

src/transformers/models/mask2former/image_processing_mask2former.py

Lines changed: 86 additions & 0 deletions
@@ -1140,6 +1140,92 @@ def post_process_instance_segmentation(
         results.append({"segmentation": segmentation, "segments_info": segments})
         return results
 
+    def post_process_instance_segmentation_v2(
+        self,
+        outputs,
+        threshold: float = 0.5,
+        mask_threshold: float = 0.5,
+        overlap_mask_area_threshold: float = 0.8,
+        target_sizes: Optional[List[Tuple[int, int]]] = None,
+    ) -> List[Dict]:
+        """
+        Converts the output of [`Mask2FormerForUniversalSegmentationOutput`] into instance segmentation
+        predictions. Only supports PyTorch.
+
+        Args:
+            outputs ([`Mask2FormerForUniversalSegmentation`]):
+                Raw outputs of the model.
+            threshold (`float`, *optional*, defaults to 0.5):
+                The probability score threshold used to keep predicted instance masks.
+            mask_threshold (`float`, *optional*, defaults to 0.5):
+                Threshold used when turning the predicted masks into binary values.
+            overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
+                The overlap mask area threshold used to merge or discard small disconnected parts within each
+                binary instance mask.
+            target_sizes (`List[Tuple[int, int]]`, *optional*):
+                List of length `batch_size`, where each item is the requested final size `(height, width)` of the
+                corresponding prediction. If unset, predictions are not resized.
+        Returns:
+            `List[Dict]`: A list of dictionaries, one per image, each containing three keys:
+            - **score** -- Prediction scores of the kept instances.
+            - **label** -- Integer label / semantic class ids of the kept instances.
+            - **segmentation** -- A boolean tensor of shape `(num_instances, height, width)` with one binary mask
+              per kept instance. Empty if no mask is found above `threshold`.
+        """
+        # [batch_size, num_queries, num_classes+1]
+        class_queries_logits = outputs.class_queries_logits
+        # [batch_size, num_queries, height, width]
+        masks_queries_logits = outputs.masks_queries_logits
+
+        # Scale back to the preprocessed image size - (384, 384) for all models
+        masks_queries_logits = torch.nn.functional.interpolate(
+            masks_queries_logits, size=(384, 384), mode="bilinear", align_corners=False
+        )
+
+        device = masks_queries_logits.device
+        num_classes = class_queries_logits.shape[-1] - 1
+        num_queries = class_queries_logits.shape[-2]
+
+        # Loop over items in the batch
+        results: List[Dict[str, TensorType]] = []
+
+        for i in range(class_queries_logits.shape[0]):
+            mask_pred = masks_queries_logits[i]
+            mask_cls = class_queries_logits[i]
+
+            # Class probabilities per query, dropping the trailing "no object" class
+            scores = torch.nn.functional.softmax(mask_cls, dim=-1)[:, :-1]
+            labels = torch.arange(num_classes, device=device).unsqueeze(0).repeat(num_queries, 1).flatten(0, 1)
+
+            scores_per_image, topk_indices = scores.flatten(0, 1).topk(num_queries, sorted=False)
+            labels_per_image = labels[topk_indices]
+
+            # Map flattened (query, class) indices back to query indices
+            topk_indices = torch.div(topk_indices, num_classes, rounding_mode="floor")
+            mask_pred = mask_pred[topk_indices]
+            pred_masks = (mask_pred > 0).float()
+
+            # Average mask probability over the predicted foreground region
+            mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * pred_masks.flatten(1)).sum(1) / (
+                pred_masks.flatten(1).sum(1) + 1e-6
+            )
+            pred_scores = scores_per_image * mask_scores_per_image
+            pred_classes = labels_per_image
+
+            mask_pred, pred_scores, pred_classes = remove_low_and_no_objects(
+                mask_pred, pred_scores, pred_classes, threshold, num_classes
+            )
+            # Re-binarize after filtering so the masks stay aligned with the filtered scores and classes
+            pred_masks = (mask_pred > 0).float()
+
+            if target_sizes is not None:
+                size = target_sizes[i] if isinstance(target_sizes[i], tuple) else target_sizes[i].cpu().tolist()
+                pred_masks = torch.nn.functional.interpolate(pred_masks.unsqueeze(0), size=size, mode="nearest")[0]
+
+            keep = pred_scores > threshold
+
+            score = pred_scores[keep]
+            label = pred_classes[keep]
+            segmentation = pred_masks[keep] > mask_threshold
+
+            results.append({"score": score, "label": label, "segmentation": segmentation})
+        return results
+
     def post_process_panoptic_segmentation(
         self,
         outputs,
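Unlike post_process_instance_segmentation, which returns a segment-id map plus a list of segments_info dicts, the v2 variant returns plain tensors ("score", "label", and a stacked binary-mask "segmentation"), a layout that maps more directly onto Triton Inference Server's tensor-in/tensor-out interface. A minimal usage sketch, again illustrative only (the checkpoint name and image path are stand-ins):

import torch
from PIL import Image
from transformers import Mask2FormerForUniversalSegmentation, Mask2FormerImageProcessor

checkpoint = "facebook/mask2former-swin-tiny-coco-instance"  # stand-in checkpoint
processor = Mask2FormerImageProcessor.from_pretrained(checkpoint)
model = Mask2FormerForUniversalSegmentation.from_pretrained(checkpoint)

image = Image.open("example.jpg")  # stand-in input image
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

result = processor.post_process_instance_segmentation_v2(
    outputs, threshold=0.5, target_sizes=[image.size[::-1]]
)[0]
# All three outputs are plain tensors, so they can be passed back
# directly as named output tensors from a Triton Python backend.
scores, labels, masks = result["score"], result["label"], result["segmentation"]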

src/transformers/models/yolov6/image_processing_yolov6.py

Lines changed: 0 additions & 1 deletion
@@ -14,7 +14,6 @@
 # limitations under the License.
 """Image processor class for YOLOS."""
 
-import math
 import pathlib
 from typing import Any, Dict, Iterable, List, Optional, Tuple, Union