import pytest
import torch

-from common_utils import assert_equal
+from common_utils import assert_equal, cpu_and_gpu
from test_prototype_transforms_functional import (
    make_bounding_box,
    make_bounding_boxes,
    make_one_hot_labels,
    make_segmentation_mask,
)
+from torchvision.ops.boxes import box_iou
from torchvision.prototype import features, transforms
from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image

@@ -1127,6 +1128,124 @@ def test_ctor(self, trfms):
        assert isinstance(output, torch.Tensor)


+class TestRandomIoUCrop:
+    @pytest.mark.parametrize("device", cpu_and_gpu())
+    @pytest.mark.parametrize("options", [[0.5, 0.9], [2.0]])
+    def test__get_params(self, device, options, mocker):
+        image = mocker.MagicMock(spec=features.Image)
+        image.num_channels = 3
+        image.image_size = (24, 32)
+        bboxes = features.BoundingBox(
+            torch.tensor([[1, 1, 10, 10], [20, 20, 23, 23], [1, 20, 10, 23], [20, 1, 23, 10]]),
+            format="XYXY",
+            image_size=image.image_size,
+            device=device,
+        )
+        sample = [image, bboxes]
+
+        transform = transforms.RandomIoUCrop(sampler_options=options)
+
+        n_samples = 5
+        for _ in range(n_samples):
+
+            params = transform._get_params(sample)
+
+            if options == [2.0]:
+                assert len(params) == 0
+                return
+
+            assert len(params["is_within_crop_area"]) > 0
+            assert params["is_within_crop_area"].dtype == torch.bool
+
+            orig_h = image.image_size[0]
+            orig_w = image.image_size[1]
+            assert int(transform.min_scale * orig_h) <= params["height"] <= int(transform.max_scale * orig_h)
+            assert int(transform.min_scale * orig_w) <= params["width"] <= int(transform.max_scale * orig_w)
+
+            left, top = params["left"], params["top"]
+            new_h, new_w = params["height"], params["width"]
+            ious = box_iou(
+                bboxes,
+                torch.tensor([[left, top, left + new_w, top + new_h]], dtype=bboxes.dtype, device=bboxes.device),
+            )
+            assert ious.max() >= options[0] or ious.max() >= options[1], f"{ious} vs {options}"
+
+    def test__transform_empty_params(self, mocker):
+        transform = transforms.RandomIoUCrop(sampler_options=[2.0])
+        image = features.Image(torch.rand(1, 3, 4, 4))
+        bboxes = features.BoundingBox(torch.tensor([[1, 1, 2, 2]]), format="XYXY", image_size=(4, 4))
+        label = features.Label(torch.tensor([1]))
+        sample = [image, bboxes, label]
+        # Let's mock transform._get_params to control the output:
+        transform._get_params = mocker.MagicMock(return_value={})
+        output = transform(sample)
+        torch.testing.assert_close(output, sample)
+
+    def test_forward_assertion(self):
+        transform = transforms.RandomIoUCrop()
+        with pytest.raises(
+            TypeError,
+            match="requires input sample to contain Images or PIL Images, BoundingBoxes and Labels or OneHotLabels",
+        ):
+            transform(torch.tensor(0))
+
+    def test__transform(self, mocker):
+        transform = transforms.RandomIoUCrop()
+
+        image = features.Image(torch.rand(3, 32, 24))
+        bboxes = make_bounding_box(format="XYXY", image_size=(32, 24), extra_dims=(6,))
+        label = features.Label(torch.randint(0, 10, size=(6,)))
+        ohe_label = features.OneHotLabel(torch.zeros(6, 10).scatter_(1, label.unsqueeze(1), 1))
+        masks = make_segmentation_mask((32, 24))
+        ohe_masks = features.SegmentationMask(torch.randint(0, 2, size=(6, 32, 24)))
+        sample = [image, bboxes, label, ohe_label, masks, ohe_masks]
+
+        fn = mocker.patch("torchvision.prototype.transforms.functional.crop", side_effect=lambda x, **params: x)
+        is_within_crop_area = torch.tensor([0, 1, 0, 1, 0, 1], dtype=torch.bool)
+
+        params = dict(top=1, left=2, height=12, width=12, is_within_crop_area=is_within_crop_area)
+        transform._get_params = mocker.MagicMock(return_value=params)
+        output = transform(sample)
+
+        assert fn.call_count == 4
+
+        expected_calls = [
+            mocker.call(image, top=params["top"], left=params["left"], height=params["height"], width=params["width"]),
+            mocker.call(bboxes, top=params["top"], left=params["left"], height=params["height"], width=params["width"]),
+            mocker.call(masks, top=params["top"], left=params["left"], height=params["height"], width=params["width"]),
+            mocker.call(
+                ohe_masks, top=params["top"], left=params["left"], height=params["height"], width=params["width"]
+            ),
+        ]
+
+        fn.assert_has_calls(expected_calls)
+
+        expected_within_targets = sum(is_within_crop_area)
+
+        # check number of bboxes vs number of labels:
+        output_bboxes = output[1]
+        assert isinstance(output_bboxes, features.BoundingBox)
+        assert len(output_bboxes) == expected_within_targets
+
+        # check labels
+        output_label = output[2]
+        assert isinstance(output_label, features.Label)
+        assert len(output_label) == expected_within_targets
+        torch.testing.assert_close(output_label, label[is_within_crop_area])
+
+        output_ohe_label = output[3]
+        assert isinstance(output_ohe_label, features.OneHotLabel)
+        torch.testing.assert_close(output_ohe_label, ohe_label[is_within_crop_area])
+
+        output_masks = output[4]
+        assert isinstance(output_masks, features.SegmentationMask)
+        assert output_masks.shape[:-2] == masks.shape[:-2]
+
+        output_ohe_masks = output[5]
+        assert isinstance(output_ohe_masks, features.SegmentationMask)
+        assert len(output_ohe_masks) == expected_within_targets
+
+
class TestScaleJitter:
    def test__get_params(self, mocker):
        image_size = (24, 32)