Skip to content

Commit 1a90c5f

Browse files
authored
AA functions: add more stringent label/shape type compatibility checking (#9005)
These changes are meant to enforce several conditions: 1. A function should not be able to create shapes of type incompatible with the type of the label it declares (for example, if a function's label has type `rectangle`, it should not be able to create ellipses with that label). 2. A function should not be able to create shapes of type incompatible with the type of the label in the task being annotated. 3. If a function's declared label has a type incompatible with the type of the corresponding task label, then it should not run at all (since it would be impossible for it to output a shape with that label that wouldn't violate either condition 1 or 2). Altogether, these restrictions ensure that we don't create any shapes in a task that aren't compatible with that task's label types. In addition, set explicit label types for the predefined functions.
1 parent 49d4eaa commit 1a90c5f

File tree

7 files changed

+199
-64
lines changed

7 files changed

+199
-64
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
### Added
2+
3+
- \[SDK\] The shapes output by auto-annotation functions are now checked
4+
for compatibility with the function's and the task's label specs
5+
(<https://github.com/cvat-ai/cvat/pull/9005>)

cvat-sdk/cvat_sdk/auto_annotation/driver.py

+54-26
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,28 @@ class _AnnotationMapper:
5656
class _LabelIdMapping:
5757
id: int
5858
sublabels: Mapping[int, Optional[_AnnotationMapper._SublabelIdMapping]]
59-
expected_num_elements: int = 0
59+
expected_num_elements: int
60+
expected_type: str
6061

6162
_SpecIdMapping: TypeAlias = Mapping[int, Optional[_LabelIdMapping]]
6263

6364
_spec_id_mapping: _SpecIdMapping
6465

66+
def _get_expected_function_output_type(self, fun_label, ds_label):
67+
fun_output_type = getattr(fun_label, "type", "any")
68+
if fun_output_type == "any":
69+
return ds_label.type
70+
71+
if self._conv_mask_to_poly and fun_output_type == "mask":
72+
fun_output_type = "polygon"
73+
74+
if not self._are_label_types_compatible(fun_output_type, ds_label.type):
75+
raise BadFunctionError(
76+
f"label {fun_label.name!r} has type {fun_output_type!r} in the function,"
77+
f" but {ds_label.type!r} in the dataset"
78+
)
79+
return fun_output_type
80+
6581
def _build_label_id_mapping(
6682
self,
6783
fun_label: models.ILabel,
@@ -70,36 +86,37 @@ def _build_label_id_mapping(
7086
label_nm: _LabelNameMapping,
7187
allow_unmatched_labels: bool,
7288
) -> Optional[_LabelIdMapping]:
73-
sl_map = {}
74-
75-
if getattr(fun_label, "sublabels", []):
76-
ds_sublabels_by_name = {ds_sl.name: ds_sl for ds_sl in ds_label.sublabels}
77-
78-
def sublabel_mapping(fun_sl: models.ILabel) -> Optional[int]:
79-
sublabel_nm = label_nm.map_sublabel(fun_sl.name)
80-
if sublabel_nm is None:
81-
return None
82-
83-
ds_sl = ds_sublabels_by_name.get(sublabel_nm.name)
84-
if not ds_sl:
85-
if not allow_unmatched_labels:
86-
raise BadFunctionError(
87-
f"sublabel {fun_sl.name!r} of label {fun_label.name!r} is not in dataset"
88-
)
89-
90-
self._logger.info(
91-
"sublabel %r of label %r is not in dataset; any annotations using it will be ignored",
92-
fun_sl.name,
93-
fun_label.name,
89+
ds_sublabels_by_name = {ds_sl.name: ds_sl for ds_sl in ds_label.sublabels}
90+
91+
def sublabel_mapping(fun_sl: models.ILabel) -> Optional[int]:
92+
sublabel_nm = label_nm.map_sublabel(fun_sl.name)
93+
if sublabel_nm is None:
94+
return None
95+
96+
ds_sl = ds_sublabels_by_name.get(sublabel_nm.name)
97+
if not ds_sl:
98+
if not allow_unmatched_labels:
99+
raise BadFunctionError(
100+
f"sublabel {fun_sl.name!r} of label {fun_label.name!r} is not in dataset"
94101
)
95-
return None
96102

97-
return ds_sl.id
103+
self._logger.info(
104+
"sublabel %r of label %r is not in dataset; any annotations using it will be ignored",
105+
fun_sl.name,
106+
fun_label.name,
107+
)
108+
return None
98109

99-
sl_map = {fun_sl.id: sublabel_mapping(fun_sl) for fun_sl in fun_label.sublabels}
110+
return ds_sl.id
100111

101112
return self._LabelIdMapping(
102-
ds_label.id, sublabels=sl_map, expected_num_elements=len(ds_label.sublabels)
113+
ds_label.id,
114+
sublabels={
115+
fun_sl.id: sublabel_mapping(fun_sl)
116+
for fun_sl in getattr(fun_label, "sublabels", [])
117+
},
118+
expected_num_elements=len(ds_label.sublabels),
119+
expected_type=self._get_expected_function_output_type(fun_label, ds_label),
103120
)
104121

105122
def _build_spec_id_mapping(
@@ -254,6 +271,12 @@ def _remap_shape(self, shape: models.LabeledShapeRequest, ds_frame: int) -> bool
254271

255272
shape.label_id = label_id_mapping.id
256273

274+
if not self._are_label_types_compatible(shape.type.value, label_id_mapping.expected_type):
275+
raise BadFunctionError(
276+
f"function output shape of type {shape.type.value!r}"
277+
f" (expected {label_id_mapping.expected_type!r})"
278+
)
279+
257280
if shape.type.value == "mask" and self._conv_mask_to_poly:
258281
raise BadFunctionError("function output mask shape despite conv_mask_to_poly=True")
259282

@@ -269,6 +292,11 @@ def _remap_shape(self, shape: models.LabeledShapeRequest, ds_frame: int) -> bool
269292
def validate_and_remap(self, shapes: list[models.LabeledShapeRequest], ds_frame: int) -> None:
270293
shapes[:] = [shape for shape in shapes if self._remap_shape(shape, ds_frame)]
271294

295+
@staticmethod
296+
def _are_label_types_compatible(source_type: str, destination_type: str) -> bool:
297+
assert source_type != "any"
298+
return destination_type == "any" or destination_type == source_type
299+
272300

273301
@attrs.frozen(kw_only=True)
274302
class _DetectionFunctionContextImpl(DetectionFunctionContext):

cvat-sdk/cvat_sdk/auto_annotation/functions/_torchvision.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010

1111

1212
class TorchvisionFunction:
13+
_label_type = "any"
14+
1315
def __init__(self, model_name: str, weights_name: str = "DEFAULT", **kwargs) -> None:
1416
weights_enum = torchvision.models.get_model_weights(model_name)
1517
self._weights = weights_enum[weights_name]
@@ -21,7 +23,7 @@ def __init__(self, model_name: str, weights_name: str = "DEFAULT", **kwargs) ->
2123
def spec(self) -> cvataa.DetectionFunctionSpec:
2224
return cvataa.DetectionFunctionSpec(
2325
labels=[
24-
cvataa.label_spec(cat, i)
26+
cvataa.label_spec(cat, i, type=self._label_type)
2527
for i, cat in enumerate(self._weights.meta["categories"])
2628
if cat != "N/A"
2729
]

cvat-sdk/cvat_sdk/auto_annotation/functions/torchvision_detection.py

+2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111

1212

1313
class _TorchvisionDetectionFunction(TorchvisionFunction):
14+
_label_type = "rectangle"
15+
1416
def detect(
1517
self, context: cvataa.DetectionFunctionContext, image: PIL.Image.Image
1618
) -> list[models.LabeledShapeRequest]:

cvat-sdk/cvat_sdk/auto_annotation/functions/torchvision_instance_segmentation.py

+2
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ def _generate_shapes(
5050

5151

5252
class _TorchvisionInstanceSegmentationFunction(TorchvisionFunction):
53+
_label_type = "mask"
54+
5355
def detect(
5456
self, context: cvataa.DetectionFunctionContext, image: PIL.Image.Image
5557
) -> list[models.LabeledShapeRequest]:

site/content/en/docs/api_sdk/sdk/auto-annotation.md

+3-2
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,9 @@ class TorchvisionDetectionFunction:
6363
# describe the annotations
6464
return cvataa.DetectionFunctionSpec(
6565
labels=[
66-
cvataa.label_spec(cat, i)
67-
for i, cat in enumerate(self._weights.meta['categories'])
66+
cvataa.label_spec(cat, i, type="rectangle")
67+
for i, cat in enumerate(self._weights.meta["categories"])
68+
if cat != "N/A"
6869
]
6970
)
7071

0 commit comments

Comments
 (0)