Skip to content

Commit 5d05ccb

Browse files
authored
feat: avoid serialize input value from int string to float (#26878)
* fix: avoid serialize input value from int string to float * fix: fix ci * fix: pylint
1 parent cece86a commit 5d05ccb

File tree

4 files changed

+40
-47
lines changed

4 files changed

+40
-47
lines changed

sdk/ml/azure-ai-ml/azure/ai/ml/_internal/_schema/input_output.py

+4-24
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66

77
from azure.ai.ml._schema import StringTransformedEnum, UnionField, PatchedSchemaMeta
88
from azure.ai.ml._schema.component.input_output import InputPortSchema, ParameterSchema
9-
from azure.ai.ml._schema.core.fields import DumpableFloatField, DumpableIntegerField, DumpableEnumField
10-
9+
from azure.ai.ml._schema.core.fields import DumpableEnumField
10+
from azure.ai.ml._schema.job.input_output_fields_provider import PrimitiveValueField
1111

1212
SUPPORTED_INTERNAL_PARAM_TYPES = [
1313
"integer",
@@ -74,29 +74,9 @@ class InternalEnumParameterSchema(ParameterSchema):
7474
required=True,
7575
data_key="type",
7676
)
77-
default = UnionField(
78-
[
79-
DumpableIntegerField(strict=True),
80-
# Use DumpableFloatField to avoid '1'(str) serialized to 1.0(float)
81-
DumpableFloatField(),
82-
# put string schema after Int and Float to make sure they won't dump to string
83-
fields.Str(),
84-
# fields.Bool comes last since it'll parse anything non-falsy to True
85-
fields.Bool(),
86-
],
87-
)
77+
default = PrimitiveValueField()
8878
enum = fields.List(
89-
UnionField(
90-
[
91-
DumpableIntegerField(strict=True),
92-
# Use DumpableFloatField to avoid '1'(str) serialized to 1.0(float)
93-
DumpableFloatField(),
94-
# put string schema after Int and Float to make sure they won't dump to string
95-
fields.Str(),
96-
# fields.Bool comes last since it'll parse anything non-falsy to True
97-
fields.Bool(),
98-
]
99-
),
79+
PrimitiveValueField(),
10080
required=True,
10181
)
10282

sdk/ml/azure-ai-ml/azure/ai/ml/_schema/core/fields.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -675,15 +675,23 @@ def _deserialize(self, value, attr, data, **kwargs) -> str:
675675
class DumpableIntegerField(fields.Integer):
676676
def _serialize(self, value, attr, obj, **kwargs) -> typing.Optional[typing.Union[str, _T]]:
677677
if self.strict and not isinstance(value, int):
678-
raise ValidationError("Given value is not an integer")
678+
# this implementation can serialize bool to bool
679+
raise self.make_error("invalid", input=value)
679680
return super()._serialize(value, attr, obj, **kwargs)
680681

681682

682683
class DumpableFloatField(fields.Float):
684+
def __init__(self, *, strict: bool = False, allow_nan: bool = False, as_string: bool = False, **kwargs):
685+
self.strict = strict
686+
super().__init__(allow_nan=allow_nan, as_string=as_string, **kwargs)
687+
688+
def _validated(self, value):
689+
if self.strict and not isinstance(value, float):
690+
raise self.make_error("invalid", input=value)
691+
return super()._validated(value)
692+
683693
def _serialize(self, value, attr, obj, **kwargs) -> typing.Optional[typing.Union[str, _T]]:
684-
if not isinstance(value, float):
685-
raise ValidationError("Given value is not a float")
686-
return super()._serialize(value, attr, obj, **kwargs)
694+
return super()._serialize(self._validated(value), attr, obj, **kwargs)
687695

688696

689697
class DumpableStringField(fields.String):

sdk/ml/azure-ai-ml/azure/ai/ml/_schema/job/input_output_fields_provider.py

+23-18
Original file line numberDiff line numberDiff line change
@@ -23,24 +23,7 @@ def InputsField(**kwargs):
2323
NestedField(ModelInputSchema),
2424
NestedField(MLTableInputSchema),
2525
NestedField(InputLiteralValueSchema),
26-
UnionField(
27-
[
28-
# Note: order matters here - to make sure value parsed correctly.
29-
# By default when strict is false, marshmallow downcasts float to int.
30-
# Setting it to true will throw a validation error when loading a float to int.
31-
# https://github.com/marshmallow-code/marshmallow/pull/755
32-
# Use DumpableIntegerField to make sure there will be validation error when
33-
# loading/dumping a float to int.
34-
DumpableIntegerField(strict=True),
35-
# Use DumpableFloatField to avoid '1'(str) serialized to 1.0(float)
36-
DumpableFloatField(),
37-
# put string schema after Int and Float to make sure they won't dump to string
38-
fields.Str(),
39-
# fields.Bool comes last since it'll parse anything non-falsy to True
40-
fields.Bool(),
41-
],
42-
is_strict=False,
43-
),
26+
PrimitiveValueField(is_strict=False),
4427
# This ordering of types for the values keyword is intentional. The ordering of types
4528
# determines what order schema values are matched and cast in. Changing the current ordering can
4629
# result in values being mis-cast such as 1.0 translating into True.
@@ -59,3 +42,25 @@ def OutputsField(**kwargs):
5942
metadata={"description": "Outputs of a job."},
6043
**kwargs
6144
)
45+
46+
47+
def PrimitiveValueField(**kwargs):
48+
return UnionField(
49+
[
50+
# Note: order matters here - to make sure value parsed correctly.
51+
# By default when strict is false, marshmallow downcasts float to int.
52+
# Setting it to true will throw a validation error when loading a float to int.
53+
# https://github.com/marshmallow-code/marshmallow/pull/755
54+
# Use DumpableIntegerField to make sure there will be validation error when
55+
# loading/dumping a float to int.
56+
# note that this field can serialize bool instance but cannot deserialize bool instance.
57+
DumpableIntegerField(strict=True),
58+
# Use DumpableFloatField with strict of True to avoid '1'(str) serialized to 1.0(float)
59+
DumpableFloatField(strict=True),
60+
# put string schema after Int and Float to make sure they won't dump to string
61+
fields.Str(),
62+
# fields.Bool comes last since it'll parse anything non-falsy to True
63+
fields.Bool(),
64+
],
65+
**kwargs,
66+
)

sdk/ml/azure-ai-ml/tests/pipeline_job/unittests/test_pipeline_job_schema.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
LearningRateScheduler,
2121
StochasticOptimizer,
2222
)
23-
from azure.ai.ml._restclient.v2022_10_01_preview.models import JobService as RestJobService
2423
from azure.ai.ml._utils.utils import camel_to_snake, dump_yaml_to_file, is_data_binding_expression, load_yaml
2524
from azure.ai.ml.constants._common import ARM_ID_PREFIX
2625
from azure.ai.ml.constants._component import ComponentJobConstants
@@ -1729,6 +1728,7 @@ def test_dump_pipeline_inputs(self):
17291728
"integer_input": 15,
17301729
"bool_input": False,
17311730
"string_input": "hello",
1731+
"string_integer_input": "43",
17321732
}
17331733

17341734
job = load_job(test_path, params_override=[{"inputs": expected_inputs}])

0 commit comments

Comments
 (0)