
Commit b5c961d

[proto] Fixed xfailed geom functional segm mask tests (#6546)
1 parent 15a9a93 commit b5c961d

File tree: 1 file changed (+23 -39 lines changed)

test/test_prototype_transforms_functional.py

Lines changed: 23 additions & 39 deletions
@@ -163,7 +163,7 @@ def make_segmentation_masks(
         yield make_segmentation_mask(size=size, dtype=dtype, extra_dims=extra_dims_)

     for dtype, extra_dims_, num_objects_ in itertools.product(dtypes, extra_dims, num_objects):
-        yield make_segmentation_mask(num_objects=num_objects_, dtype=dtype, extra_dims=extra_dims_)
+        yield make_segmentation_mask(size=sizes[0], num_objects=num_objects_, dtype=dtype, extra_dims=extra_dims_)


 class SampleInput:
@@ -904,21 +904,14 @@ def test_correctness_affine_bounding_box_on_fixed_input(device):
     torch.testing.assert_close(output_boxes.tolist(), expected_bboxes)


-incorrect_expected_segmentation_mask_setup = pytest.mark.xfail(
-    reason="This test fails because the expected result computation is wrong. Fix ASAP.",
-    strict=False,
-)
-
-
-@incorrect_expected_segmentation_mask_setup
 @pytest.mark.parametrize("angle", [-54, 56])
 @pytest.mark.parametrize("translate", [-7, 8])
 @pytest.mark.parametrize("scale", [0.89, 1.12])
 @pytest.mark.parametrize("shear", [4])
 @pytest.mark.parametrize("center", [None, (12, 14)])
 def test_correctness_affine_segmentation_mask(angle, translate, scale, shear, center):
     def _compute_expected_mask(mask, angle_, translate_, scale_, shear_, center_):
-        assert mask.ndim == 3 and mask.shape[0] == 1
+        assert mask.ndim == 3
         affine_matrix = _compute_affine_matrix(angle_, translate_, scale_, shear_, center_)
         inv_affine_matrix = np.linalg.inv(affine_matrix)
         inv_affine_matrix = inv_affine_matrix[:2, :]
@@ -927,10 +920,11 @@ def _compute_expected_mask(mask, angle_, translate_, scale_, shear_, center_):
         for out_y in range(expected_mask.shape[1]):
             for out_x in range(expected_mask.shape[2]):
                 output_pt = np.array([out_x + 0.5, out_y + 0.5, 1.0])
-                input_pt = np.floor(np.dot(inv_affine_matrix, output_pt)).astype(np.int32)
+                input_pt = np.floor(np.dot(inv_affine_matrix, output_pt)).astype("int")
                 in_x, in_y = input_pt[:2]
                 if 0 <= in_x < mask.shape[2] and 0 <= in_y < mask.shape[1]:
-                    expected_mask[0, out_y, out_x] = mask[0, in_y, in_x]
+                    for i in range(expected_mask.shape[0]):
+                        expected_mask[i, out_y, out_x] = mask[i, in_y, in_x]
         return expected_mask.to(mask.device)

     for mask in make_segmentation_masks(extra_dims=((), (4,))):
@@ -1128,13 +1122,12 @@ def test_correctness_rotate_bounding_box_on_fixed_input(device, expand):
     torch.testing.assert_close(output_boxes.tolist(), expected_bboxes)


-@incorrect_expected_segmentation_mask_setup
-@pytest.mark.parametrize("angle", range(-90, 90, 37))
+@pytest.mark.parametrize("angle", range(-89, 90, 37))
 @pytest.mark.parametrize("expand, center", [(True, None), (False, None), (False, (12, 14))])
 def test_correctness_rotate_segmentation_mask(angle, expand, center):
     def _compute_expected_mask(mask, angle_, expand_, center_):
-        assert mask.ndim == 3 and mask.shape[0] == 1
-        image_size = mask.shape[-2:]
+        assert mask.ndim == 3
+        c, *image_size = mask.shape
         affine_matrix = _compute_affine_matrix(angle_, [0.0, 0.0], 1.0, [0.0, 0.0], center_)
         inv_affine_matrix = np.linalg.inv(affine_matrix)

@@ -1155,22 +1148,23 @@ def _compute_expected_mask(mask, angle_, expand_, center_):
             max_vals = np.max(new_points, axis=0)[:2]
             cmax = np.ceil(np.trunc(max_vals * 1e4) * 1e-4)
             cmin = np.floor(np.trunc((min_vals + 1e-8) * 1e4) * 1e-4)
-            new_width, new_height = (cmax - cmin).astype("int32").tolist()
+            new_width, new_height = (cmax - cmin).astype("int").tolist()
             tr = np.array([-(new_width - width) / 2.0, -(new_height - height) / 2.0, 1.0]) @ inv_affine_matrix.T

             inv_affine_matrix[:2, 2] = tr[:2]
             image_size = [new_height, new_width]

         inv_affine_matrix = inv_affine_matrix[:2, :]
-        expected_mask = torch.zeros(1, *image_size, dtype=mask.dtype)
+        expected_mask = torch.zeros(c, *image_size, dtype=mask.dtype)

         for out_y in range(expected_mask.shape[1]):
             for out_x in range(expected_mask.shape[2]):
                 output_pt = np.array([out_x + 0.5, out_y + 0.5, 1.0])
-                input_pt = np.floor(np.dot(inv_affine_matrix, output_pt)).astype(np.int32)
+                input_pt = np.floor(np.dot(inv_affine_matrix, output_pt)).astype("int")
                 in_x, in_y = input_pt[:2]
                 if 0 <= in_x < mask.shape[2] and 0 <= in_y < mask.shape[1]:
-                    expected_mask[0, out_y, out_x] = mask[0, in_y, in_x]
+                    for i in range(expected_mask.shape[0]):
+                        expected_mask[i, out_y, out_x] = mask[i, in_y, in_x]
         return expected_mask.to(mask.device)

     for mask in make_segmentation_masks(extra_dims=((), (4,))):
@@ -1617,7 +1611,6 @@ def _compute_expected_bbox(bbox, pcoeffs_):
     torch.testing.assert_close(output_bboxes, expected_bboxes, rtol=1e-5, atol=1e-5)


-@incorrect_expected_segmentation_mask_setup
 @pytest.mark.parametrize("device", cpu_and_gpu())
 @pytest.mark.parametrize(
     "startpoints, endpoints",
@@ -1629,19 +1622,9 @@ def _compute_expected_bbox(bbox, pcoeffs_):
 )
 def test_correctness_perspective_segmentation_mask(device, startpoints, endpoints):
     def _compute_expected_mask(mask, pcoeffs_):
-        assert mask.ndim == 3 and mask.shape[0] == 1
-        m1 = np.array(
-            [
-                [pcoeffs_[0], pcoeffs_[1], pcoeffs_[2]],
-                [pcoeffs_[3], pcoeffs_[4], pcoeffs_[5]],
-            ]
-        )
-        m2 = np.array(
-            [
-                [pcoeffs_[6], pcoeffs_[7], 1.0],
-                [pcoeffs_[6], pcoeffs_[7], 1.0],
-            ]
-        )
+        assert mask.ndim == 3
+        m1 = np.array([[pcoeffs_[0], pcoeffs_[1], pcoeffs_[2]], [pcoeffs_[3], pcoeffs_[4], pcoeffs_[5]]])
+        m2 = np.array([[pcoeffs_[6], pcoeffs_[7], 1.0], [pcoeffs_[6], pcoeffs_[7], 1.0]])

         expected_mask = torch.zeros_like(mask.cpu())
         for out_y in range(expected_mask.shape[1]):
@@ -1654,7 +1637,8 @@ def _compute_expected_mask(mask, pcoeffs_):

                 in_x, in_y = input_pt[:2]
                 if 0 <= in_x < mask.shape[2] and 0 <= in_y < mask.shape[1]:
-                    expected_mask[0, out_y, out_x] = mask[0, in_y, in_x]
+                    for i in range(expected_mask.shape[0]):
+                        expected_mask[i, out_y, out_x] = mask[i, in_y, in_x]
         return expected_mask.to(mask.device)

     pcoeffs = _get_perspective_coeffs(startpoints, endpoints)
@@ -1819,7 +1803,6 @@ def test_correctness_gaussian_blur_image_tensor(device, image_size, dt, ksize, s
     torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg=f"{ksize}, {sigma}")


-@incorrect_expected_segmentation_mask_setup
 @pytest.mark.parametrize("device", cpu_and_gpu())
 @pytest.mark.parametrize(
     "fn, make_samples", [(F.elastic_image_tensor, make_images), (F.elastic_segmentation_mask, make_segmentation_masks)]
@@ -1829,10 +1812,11 @@ def test_correctness_elastic_image_or_mask_tensor(device, fn, make_samples):
     for sample in make_samples(sizes=((64, 76),), extra_dims=((), (4,))):
         c, h, w = sample.shape[-3:]
         # Setup a dummy image with 4 points
-        sample[..., in_box[1], in_box[0]] = torch.tensor([12, 34, 96, 112])[:c]
-        sample[..., in_box[3] - 1, in_box[0]] = torch.tensor([12, 34, 96, 112])[:c]
-        sample[..., in_box[3] - 1, in_box[2] - 1] = torch.tensor([12, 34, 96, 112])[:c]
-        sample[..., in_box[1], in_box[2] - 1] = torch.tensor([12, 34, 96, 112])[:c]
+        print(sample.shape)
+        sample[..., in_box[1], in_box[0]] = torch.arange(10, 10 + c)
+        sample[..., in_box[3] - 1, in_box[0]] = torch.arange(20, 20 + c)
+        sample[..., in_box[3] - 1, in_box[2] - 1] = torch.arange(30, 30 + c)
+        sample[..., in_box[1], in_box[2] - 1] = torch.arange(40, 40 + c)
         sample = sample.to(device)

         if fn == F.elastic_image_tensor:
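
Taken together, the segmentation-mask hunks above apply the same fix: drop the old single-channel assumption (mask.shape[0] == 1) and copy every channel when mapping an output pixel back through the inverse transform. Below is a minimal standalone sketch of that pattern, assuming a (C, H, W) integer mask and a 2x3 inverse affine matrix; the helper name and the identity-matrix usage example are illustrative only and are not part of the commit.

import numpy as np
import torch


def reference_geometry_mask(mask, inv_affine_matrix):
    # mask: (C, H, W) integer segmentation mask; inv_affine_matrix: 2x3 ndarray.
    c, h, w = mask.shape
    expected = torch.zeros_like(mask)
    for out_y in range(h):
        for out_x in range(w):
            # Map the centre of each output pixel back into the input image.
            output_pt = np.array([out_x + 0.5, out_y + 0.5, 1.0])
            in_x, in_y = np.floor(np.dot(inv_affine_matrix, output_pt)).astype("int")
            if 0 <= in_x < w and 0 <= in_y < h:
                # Copy all channels, not just channel 0.
                expected[:, out_y, out_x] = mask[:, in_y, in_x]
    return expected


# Usage example: an identity inverse matrix must reproduce the input mask.
mask = torch.randint(0, 5, (3, 8, 8))
identity = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
assert torch.equal(reference_geometry_mask(mask, identity), mask)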

0 commit comments
