-
Notifications
You must be signed in to change notification settings - Fork 7.1k
masks_to_bounding_boxes
op
#4290
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
cf51379
c67e035
3830dd1
926d444
f777416
cd46aa7
712131e
b6f5c42
b555c68
fc26f3a
c4d3045
4589951
6b19d67
c6c89ec
16a99a9
7115320
f4796d2
a070133
0131db3
0a23bcf
db8fb7b
f7a2c1e
c7dfcdf
5e6198a
7c78271
b9055c2
6c630c5
540c6a1
8e4fc2f
4c78297
140e429
8f2cd4a
7252723
26f68af
2c2d5dd
3a91957
e24805c
65404e9
6c89be7
b2a907c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
sphinx==3.5.4 | ||
sphinx-gallery>=0.9.0 | ||
sphinx-copybutton>=0.3.1 | ||
matplotlib | ||
numpy | ||
sphinx-copybutton>=0.3.1 | ||
sphinx-gallery>=0.9.0 | ||
sphinx==3.5.4 | ||
-e git+git://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
""" | ||
======================= | ||
Repurposing annotations | ||
======================= | ||
|
||
The following example illustrates the operations available in the torchvision.ops module for repurposing object | ||
localization annotations for different tasks (e.g. transforming masks used by instance and panoptic segmentation | ||
methods into bounding boxes used by object detection methods). | ||
""" | ||
import os.path | ||
|
||
import PIL.Image | ||
import matplotlib.patches | ||
import matplotlib.pyplot | ||
import numpy | ||
import torch | ||
from torchvision.ops import masks_to_boxes | ||
|
||
ASSETS_DIRECTORY = "../test/assets" | ||
|
||
matplotlib.pyplot.rcParams["savefig.bbox"] = "tight" | ||
|
||
#################################### | ||
# Masks | ||
# ----- | ||
# In tasks like instance and panoptic segmentation, masks are commonly defined, and are defined by this package, | ||
# as a multi-dimensional array (e.g. a NumPy array or a PyTorch tensor) with the following shape: | ||
# | ||
# (objects, height, width) | ||
# | ||
# Where objects is the number of annotated objects in the image. Each (height, width) object corresponds to exactly | ||
# one object. For example, if your input image has the dimensions 224 x 224 and has four annotated objects the shape | ||
# of your masks annotation has the following shape: | ||
# | ||
# (4, 224, 224). | ||
# | ||
# A nice property of masks is that they can be easily repurposed to be used in methods to solve a variety of object | ||
# localization tasks. | ||
# | ||
# Masks to bounding boxes | ||
# ---------------------------------------- | ||
# For example, the masks to bounding_boxes operation can be used to transform masks into bounding boxes that can be | ||
# used in methods like Faster RCNN and YOLO. | ||
|
||
with PIL.Image.open(os.path.join(ASSETS_DIRECTORY, "masks.tiff")) as image: | ||
masks = torch.zeros((image.n_frames, image.height, image.width), dtype=torch.int) | ||
|
||
for index in range(image.n_frames): | ||
image.seek(index) | ||
|
||
frame = numpy.array(image) | ||
|
||
masks[index] = torch.tensor(frame) | ||
|
||
bounding_boxes = masks_to_boxes(masks) | ||
|
||
figure = matplotlib.pyplot.figure() | ||
|
||
a = figure.add_subplot(121) | ||
b = figure.add_subplot(122) | ||
|
||
labeled_image = torch.sum(masks, 0) | ||
|
||
a.imshow(labeled_image) | ||
b.imshow(labeled_image) | ||
|
||
for bounding_box in bounding_boxes: | ||
x0, y0, x1, y1 = bounding_box | ||
|
||
rectangle = matplotlib.patches.Rectangle((x0, y0), x1 - x0, y1 - y0, linewidth=1, edgecolor="r", facecolor="none") | ||
|
||
b.add_patch(rectangle) | ||
|
||
a.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[]) | ||
b.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[]) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import os.path | ||
|
||
import PIL.Image | ||
import numpy | ||
import torch | ||
|
||
from torchvision.ops import masks_to_boxes | ||
|
||
ASSETS_DIRECTORY = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets") | ||
|
||
|
||
def test_masks_to_boxes():
    """Check masks_to_boxes box coordinates against hand-verified values
    for the seven object masks stored in test/assets/masks.tiff."""
    # Load one mask per TIFF frame into an (objects, height, width) tensor.
    with PIL.Image.open(os.path.join(ASSETS_DIRECTORY, "masks.tiff")) as image:
        masks = torch.zeros((image.n_frames, image.height, image.width), dtype=torch.int)

        for index in range(image.n_frames):
            image.seek(index)

            frame = numpy.array(image)

            masks[index] = torch.tensor(frame)

    # One expected (x1, y1, x2, y2) box per mask frame.
    expected = torch.tensor(
        [[127, 2, 165, 40],
         [2, 50, 44, 92],
         [56, 63, 98, 100],
         [139, 68, 175, 104],
         [160, 112, 198, 145],
         [49, 138, 99, 182],
         [108, 148, 152, 213]],
        dtype=torch.int32
    )

    # check_dtype=False: this test pins the box *coordinates*; the output
    # dtype of masks_to_boxes (int vs. float) is part of the op's own
    # contract and should not make a coordinate test fail.
    torch.testing.assert_close(masks_to_boxes(masks), expected, check_dtype=False)
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -297,3 +297,35 @@ def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor: | |
areai = whi[:, :, 0] * whi[:, :, 1] | ||
|
||
return iou - (areai - union) / areai | ||
|
||
|
||
def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor:
    """
    Compute the bounding boxes around the provided masks.

    Returns a [N, 4] tensor containing bounding boxes. The boxes are in the
    ``(x1, y1, x2, y2)`` format with ``0 <= x1 < x2`` and ``0 <= y1 < y2``.

    Args:
        masks (Tensor[N, H, W]): masks to transform where N is the number of
            masks and (H, W) are the spatial dimensions. A pixel belongs to
            the mask iff its value is non-zero.

    Returns:
        Tensor[N, 4]: bounding boxes
    """
    # Degenerate case: no masks at all. Keep the result on the input's
    # device with the same dtype as the non-empty path.
    if masks.numel() == 0:
        return torch.zeros((0, 4), device=masks.device, dtype=torch.float)

    n = masks.shape[0]

    # Boxes are float (not int) so they match the dtype expected by
    # downstream detection ops, and live on the masks' device.
    bounding_boxes = torch.zeros((n, 4), device=masks.device, dtype=torch.float)

    for index, mask in enumerate(masks):
        # Row/column coordinates of every non-zero pixel of this mask.
        y, x = torch.where(mask != 0)

        bounding_boxes[index, 0] = torch.min(x)
        bounding_boxes[index, 1] = torch.min(y)
        bounding_boxes[index, 2] = torch.max(x)
        bounding_boxes[index, 3] = torch.max(y)

    return bounding_boxes
Uh oh!
There was an error while loading. Please reload this page.