Skip to content

Commit 1ea73f5

Browse files
authored
Rename features.SegmentationMask to features.Mask (#6579)
* rename features.SegmentationMask -> features.Mask * rename kernels *_segmentation_mask -> *_mask and cleanup input name * cleanup * rename module _segmentation_mask.py -> _mask.py * fix test
1 parent f007a5e commit 1ea73f5

12 files changed

+212
-224
lines changed

test/prototype_common_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ def make_detection_mask(size=None, *, num_objects=None, extra_dims=(), dtype=tor
184184
num_objects = num_objects if num_objects is not None else int(torch.randint(1, 11, ()))
185185
shape = (*extra_dims, num_objects, *size)
186186
data = make_tensor(shape, low=0, high=2, dtype=dtype)
187-
return features.SegmentationMask(data)
187+
return features.Mask(data)
188188

189189

190190
def make_detection_masks(
@@ -207,7 +207,7 @@ def make_segmentation_mask(size=None, *, num_categories=None, extra_dims=(), dty
207207
num_categories = num_categories if num_categories is not None else int(torch.randint(1, 11, ()))
208208
shape = (*extra_dims, *size)
209209
data = make_tensor(shape, low=0, high=num_categories, dtype=dtype)
210-
return features.SegmentationMask(data)
210+
return features.Mask(data)
211211

212212

213213
def make_segmentation_masks(
@@ -224,7 +224,7 @@ def make_segmentation_masks(
224224
yield make_segmentation_mask(size=sizes[0], num_categories=num_categories_, dtype=dtype, extra_dims=extra_dims_)
225225

226226

227-
def make_detection_and_segmentation_masks(
227+
def make_masks(
228228
sizes=((16, 16), (7, 33), (31, 9)),
229229
dtypes=(torch.uint8,),
230230
extra_dims=((), (0,), (4,), (2, 3), (5, 0), (0, 5)),

test/test_prototype_transforms.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@
1010
from prototype_common_utils import (
1111
make_bounding_box,
1212
make_bounding_boxes,
13-
make_detection_and_segmentation_masks,
1413
make_detection_mask,
1514
make_image,
1615
make_images,
1716
make_label,
17+
make_masks,
1818
make_one_hot_labels,
1919
make_segmentation_mask,
2020
)
@@ -64,7 +64,7 @@ def parametrize_from_transforms(*transforms):
6464
make_one_hot_labels,
6565
make_vanilla_tensor_images,
6666
make_pil_images,
67-
make_detection_and_segmentation_masks,
67+
make_masks,
6868
]:
6969
inputs = list(creation_fn())
7070
try:
@@ -132,7 +132,7 @@ def test_mixup_cutmix(self, transform, input):
132132
transform(input_copy)
133133

134134
# Check if we raise an error if sample contains bbox or mask or label
135-
err_msg = "does not support bounding boxes, segmentation masks and plain labels"
135+
err_msg = "does not support bounding boxes, masks and plain labels"
136136
input_copy = dict(input)
137137
for unsup_data in [
138138
make_label(),
@@ -241,7 +241,7 @@ def test_convert_color_space_unsupported_types(self):
241241
color_space=features.ColorSpace.RGB, old_color_space=features.ColorSpace.GRAY
242242
)
243243

244-
for inpt in [make_bounding_box(format="XYXY"), make_detection_and_segmentation_masks()]:
244+
for inpt in [make_bounding_box(format="XYXY"), make_masks()]:
245245
output = transform(inpt)
246246
assert output is inpt
247247

@@ -278,13 +278,13 @@ def test_features_image(self, p):
278278

279279
assert_equal(features.Image(expected), actual)
280280

281-
def test_features_segmentation_mask(self, p):
281+
def test_features_mask(self, p):
282282
input, expected = self.input_expected_image_tensor(p)
283283
transform = transforms.RandomHorizontalFlip(p=p)
284284

285-
actual = transform(features.SegmentationMask(input))
285+
actual = transform(features.Mask(input))
286286

287-
assert_equal(features.SegmentationMask(expected), actual)
287+
assert_equal(features.Mask(expected), actual)
288288

289289
def test_features_bounding_box(self, p):
290290
input = features.BoundingBox([0, 0, 5, 5], format=features.BoundingBoxFormat.XYXY, image_size=(10, 10))
@@ -331,13 +331,13 @@ def test_features_image(self, p):
331331

332332
assert_equal(features.Image(expected), actual)
333333

334-
def test_features_segmentation_mask(self, p):
334+
def test_features_mask(self, p):
335335
input, expected = self.input_expected_image_tensor(p)
336336
transform = transforms.RandomVerticalFlip(p=p)
337337

338-
actual = transform(features.SegmentationMask(input))
338+
actual = transform(features.Mask(input))
339339

340-
assert_equal(features.SegmentationMask(expected), actual)
340+
assert_equal(features.Mask(expected), actual)
341341

342342
def test_features_bounding_box(self, p):
343343
input = features.BoundingBox([0, 0, 5, 5], format=features.BoundingBoxFormat.XYXY, image_size=(10, 10))
@@ -1253,7 +1253,7 @@ def test__transform(self, mocker):
12531253
torch.testing.assert_close(output_ohe_label, ohe_label[is_within_crop_area])
12541254

12551255
output_masks = output[4]
1256-
assert isinstance(output_masks, features.SegmentationMask)
1256+
assert isinstance(output_masks, features.Mask)
12571257
assert len(output_masks) == expected_within_targets
12581258

12591259

@@ -1372,10 +1372,10 @@ def test__extract_image_targets_assertion(self, mocker):
13721372
# labels, bboxes, masks
13731373
mocker.MagicMock(spec=features.Label),
13741374
mocker.MagicMock(spec=features.BoundingBox),
1375-
mocker.MagicMock(spec=features.SegmentationMask),
1375+
mocker.MagicMock(spec=features.Mask),
13761376
# labels, bboxes, masks
13771377
mocker.MagicMock(spec=features.BoundingBox),
1378-
mocker.MagicMock(spec=features.SegmentationMask),
1378+
mocker.MagicMock(spec=features.Mask),
13791379
]
13801380

13811381
with pytest.raises(TypeError, match="requires input sample to contain equal sized list of Images"):
@@ -1393,11 +1393,11 @@ def test__extract_image_targets(self, image_type, label_type, mocker):
13931393
# labels, bboxes, masks
13941394
mocker.MagicMock(spec=label_type),
13951395
mocker.MagicMock(spec=features.BoundingBox),
1396-
mocker.MagicMock(spec=features.SegmentationMask),
1396+
mocker.MagicMock(spec=features.Mask),
13971397
# labels, bboxes, masks
13981398
mocker.MagicMock(spec=label_type),
13991399
mocker.MagicMock(spec=features.BoundingBox),
1400-
mocker.MagicMock(spec=features.SegmentationMask),
1400+
mocker.MagicMock(spec=features.Mask),
14011401
]
14021402

14031403
images, targets = transform._extract_image_targets(flat_sample)
@@ -1413,7 +1413,7 @@ def test__extract_image_targets(self, image_type, label_type, mocker):
14131413
for target in targets:
14141414
for key, type_ in [
14151415
("boxes", features.BoundingBox),
1416-
("masks", features.SegmentationMask),
1416+
("masks", features.Mask),
14171417
("labels", label_type),
14181418
]:
14191419
assert key in target
@@ -1436,7 +1436,7 @@ def test__copy_paste(self, label_type):
14361436
"boxes": features.BoundingBox(
14371437
torch.tensor([[2.0, 3.0, 8.0, 9.0], [20.0, 20.0, 30.0, 30.0]]), format="XYXY", image_size=(32, 32)
14381438
),
1439-
"masks": features.SegmentationMask(masks),
1439+
"masks": features.Mask(masks),
14401440
"labels": label_type(labels),
14411441
}
14421442

@@ -1451,7 +1451,7 @@ def test__copy_paste(self, label_type):
14511451
"boxes": features.BoundingBox(
14521452
torch.tensor([[12.0, 13.0, 19.0, 18.0], [1.0, 15.0, 8.0, 19.0]]), format="XYXY", image_size=(32, 32)
14531453
),
1454-
"masks": features.SegmentationMask(paste_masks),
1454+
"masks": features.Mask(paste_masks),
14551455
"labels": label_type(paste_labels),
14561456
}
14571457

@@ -1586,7 +1586,7 @@ def test__transform_culling(self, mocker):
15861586
bounding_boxes = make_bounding_box(
15871587
format=features.BoundingBoxFormat.XYXY, image_size=image_size, extra_dims=(batch_size,)
15881588
)
1589-
segmentation_masks = make_detection_mask(size=image_size, extra_dims=(batch_size,))
1589+
masks = make_detection_mask(size=image_size, extra_dims=(batch_size,))
15901590
labels = make_label(size=(batch_size,))
15911591

15921592
transform = transforms.FixedSizeCrop((-1, -1))
@@ -1596,13 +1596,13 @@ def test__transform_culling(self, mocker):
15961596
output = transform(
15971597
dict(
15981598
bounding_boxes=bounding_boxes,
1599-
segmentation_masks=segmentation_masks,
1599+
masks=masks,
16001600
labels=labels,
16011601
)
16021602
)
16031603

16041604
assert_equal(output["bounding_boxes"], bounding_boxes[is_valid])
1605-
assert_equal(output["segmentation_masks"], segmentation_masks[is_valid])
1605+
assert_equal(output["masks"], masks[is_valid])
16061606
assert_equal(output["labels"], labels[is_valid])
16071607

16081608
def test__transform_bounding_box_clamping(self, mocker):

0 commit comments

Comments
 (0)