Skip to content

feat: avoid serialize input value from int string to float #26878

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

from azure.ai.ml._schema import StringTransformedEnum, UnionField, PatchedSchemaMeta
from azure.ai.ml._schema.component.input_output import InputPortSchema, ParameterSchema
from azure.ai.ml._schema.core.fields import DumpableFloatField, DumpableIntegerField, DumpableEnumField

from azure.ai.ml._schema.core.fields import DumpableEnumField
from azure.ai.ml._schema.job.input_output_fields_provider import PrimitiveValueField

SUPPORTED_INTERNAL_PARAM_TYPES = [
"integer",
Expand Down Expand Up @@ -74,29 +74,9 @@ class InternalEnumParameterSchema(ParameterSchema):
required=True,
data_key="type",
)
default = UnionField(
[
DumpableIntegerField(strict=True),
# Use DumpableFloatField to avoid '1'(str) serialized to 1.0(float)
DumpableFloatField(),
# put string schema after Int and Float to make sure they won't dump to string
fields.Str(),
# fields.Bool comes last since it'll parse anything non-falsy to True
fields.Bool(),
],
)
default = PrimitiveValueField()
enum = fields.List(
UnionField(
[
DumpableIntegerField(strict=True),
# Use DumpableFloatField to avoid '1'(str) serialized to 1.0(float)
DumpableFloatField(),
# put string schema after Int and Float to make sure they won't dump to string
fields.Str(),
# fields.Bool comes last since it'll parse anything non-falsy to True
fields.Bool(),
]
),
PrimitiveValueField(),
required=True,
)

Expand Down
16 changes: 12 additions & 4 deletions sdk/ml/azure-ai-ml/azure/ai/ml/_schema/core/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,15 +675,23 @@ def _deserialize(self, value, attr, data, **kwargs) -> str:
class DumpableIntegerField(fields.Integer):
def _serialize(self, value, attr, obj, **kwargs) -> typing.Optional[typing.Union[str, _T]]:
if self.strict and not isinstance(value, int):
raise ValidationError("Given value is not an integer")
# this implementation can serialize bool to bool
raise self.make_error("invalid", input=value)
return super()._serialize(value, attr, obj, **kwargs)


class DumpableFloatField(fields.Float):
def __init__(self, *, strict: bool = False, allow_nan: bool = False, as_string: bool = False, **kwargs):
self.strict = strict
super().__init__(allow_nan=allow_nan, as_string=as_string, **kwargs)

def _validated(self, value):
if self.strict and not isinstance(value, float):
raise self.make_error("invalid", input=value)
return super()._validated(value)

def _serialize(self, value, attr, obj, **kwargs) -> typing.Optional[typing.Union[str, _T]]:
if not isinstance(value, float):
raise ValidationError("Given value is not a float")
return super()._serialize(value, attr, obj, **kwargs)
return super()._serialize(self._validated(value), attr, obj, **kwargs)


class DumpableStringField(fields.String):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,24 +23,7 @@ def InputsField(**kwargs):
NestedField(ModelInputSchema),
NestedField(MLTableInputSchema),
NestedField(InputLiteralValueSchema),
UnionField(
[
# Note: order matters here - to make sure value parsed correctly.
# By default when strict is false, marshmallow downcasts float to int.
# Setting it to true will throw a validation error when loading a float to int.
# https://github.com/marshmallow-code/marshmallow/pull/755
# Use DumpableIntegerField to make sure there will be validation error when
# loading/dumping a float to int.
DumpableIntegerField(strict=True),
# Use DumpableFloatField to avoid '1'(str) serialized to 1.0(float)
DumpableFloatField(),
# put string schema after Int and Float to make sure they won't dump to string
fields.Str(),
# fields.Bool comes last since it'll parse anything non-falsy to True
fields.Bool(),
],
is_strict=False,
),
PrimitiveValueField(is_strict=False),
# This ordering of types for the values keyword is intentional. The ordering of types
# determines what order schema values are matched and cast in. Changing the current ordering can
# result in values being mis-cast such as 1.0 translating into True.
Expand All @@ -59,3 +42,25 @@ def OutputsField(**kwargs):
metadata={"description": "Outputs of a job."},
**kwargs
)


def PrimitiveValueField(**kwargs):
return UnionField(
[
# Note: order matters here - to make sure value parsed correctly.
# By default when strict is false, marshmallow downcasts float to int.
# Setting it to true will throw a validation error when loading a float to int.
# https://github.com/marshmallow-code/marshmallow/pull/755
# Use DumpableIntegerField to make sure there will be validation error when
# loading/dumping a float to int.
# note that this field can serialize bool instance but cannot deserialize bool instance.
DumpableIntegerField(strict=True),
# Use DumpableFloatField with strict of True to avoid '1'(str) serialized to 1.0(float)
DumpableFloatField(strict=True),
# put string schema after Int and Float to make sure they won't dump to string
fields.Str(),
# fields.Bool comes last since it'll parse anything non-falsy to True
fields.Bool(),
],
**kwargs,
)
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
LearningRateScheduler,
StochasticOptimizer,
)
from azure.ai.ml._restclient.v2022_10_01_preview.models import JobService as RestJobService
from azure.ai.ml._utils.utils import camel_to_snake, dump_yaml_to_file, is_data_binding_expression, load_yaml
from azure.ai.ml.constants._common import ARM_ID_PREFIX
from azure.ai.ml.constants._component import ComponentJobConstants
Expand Down Expand Up @@ -1729,6 +1728,7 @@ def test_dump_pipeline_inputs(self):
"integer_input": 15,
"bool_input": False,
"string_input": "hello",
"string_integer_input": "43",
}

job = load_job(test_path, params_override=[{"inputs": expected_inputs}])
Expand Down