Skip to content

Commit f007a5e

Browse files
authored
fix resize for segmentation masks without channel dimension (#6576)
* fix resize for segmentation masks without batch dim * micro improvement of stable crop * Revert "micro improvement of stable crop" This reverts commit e981e36.
1 parent ebb68f3 commit f007a5e

File tree

1 file changed

+14
-1
lines changed

1 file changed

+14
-1
lines changed

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,20 @@ def resize_image_pil(
135135
def resize_segmentation_mask(
136136
segmentation_mask: torch.Tensor, size: List[int], max_size: Optional[int] = None
137137
) -> torch.Tensor:
138-
return resize_image_tensor(segmentation_mask, size=size, interpolation=InterpolationMode.NEAREST, max_size=max_size)
138+
if segmentation_mask.ndim < 3:
139+
segmentation_mask = segmentation_mask.unsqueeze(0)
140+
needs_squeeze = True
141+
else:
142+
needs_squeeze = False
143+
144+
output = resize_image_tensor(
145+
segmentation_mask, size=size, interpolation=InterpolationMode.NEAREST, max_size=max_size
146+
)
147+
148+
if needs_squeeze:
149+
output = output.squeeze(0)
150+
151+
return output
139152

140153

141154
def resize_bounding_box(

0 commit comments

Comments
 (0)