Skip to content

Commit 1d43ea6

Browse files
authored
fix: combo fix for internal components I (#26718)
* fix: support data-binding expression on internal node runsettings * fix: raise error if unknown field errors are met in load_funcs * fix: support is_control output for internal components * feat: ignore unknown fields on internal output port definition instead of raising error * feat: enable hdinsight.compute_name * feat: support dependency files in additional_includes * feat: convert hdinsight.compute_name to arm str * feat: load all internal nodes to their own entity class instead of base node * resolve comment add * to validation class method remove super class for InternalOutputSchema
1 parent e3415c4 commit 1d43ea6

File tree

67 files changed

+6209
-5802
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

67 files changed

+6209
-5802
lines changed

sdk/ml/azure-ai-ml/azure/ai/ml/_internal/_schema/component.py

+14-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# Copyright (c) Microsoft Corporation. All rights reserved.
33
# ---------------------------------------------------------
44

5-
from marshmallow import fields, post_dump
5+
from marshmallow import fields, post_dump, INCLUDE, EXCLUDE
66

77
from azure.ai.ml._schema import NestedField, StringTransformedEnum, UnionField
88
from azure.ai.ml._schema.component.component import ComponentSchema
@@ -15,6 +15,7 @@
1515
InternalInputPortSchema,
1616
InternalOutputPortSchema,
1717
InternalParameterSchema,
18+
InternalPrimitiveOutputSchema,
1819
)
1920

2021

@@ -42,6 +43,8 @@ def all_values(cls):
4243

4344

4445
class InternalBaseComponentSchema(ComponentSchema):
46+
class Meta:
47+
unknown = INCLUDE
4548
# override name as 1p components allow . in name, which is not allowed in v2 components
4649
name = fields.Str()
4750

@@ -60,7 +63,16 @@ class InternalBaseComponentSchema(ComponentSchema):
6063
]
6164
),
6265
)
63-
outputs = fields.Dict(keys=fields.Str(), values=NestedField(InternalOutputPortSchema))
66+
# support primitive output for all internal components for now
67+
outputs = fields.Dict(
68+
keys=fields.Str(),
69+
values=UnionField(
70+
[
71+
NestedField(InternalPrimitiveOutputSchema, unknown=EXCLUDE),
72+
NestedField(InternalOutputPortSchema, unknown=EXCLUDE),
73+
]
74+
),
75+
)
6476

6577
# type field is required for registration
6678
type = StringTransformedEnum(

sdk/ml/azure-ai-ml/azure/ai/ml/_internal/_schema/input_output.py

+27-16
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,22 @@
44

55
from marshmallow import fields, post_dump, post_load
66

7-
from azure.ai.ml._schema import StringTransformedEnum, UnionField
8-
from azure.ai.ml._schema.component.input_output import InputPortSchema, OutputPortSchema, ParameterSchema
9-
from azure.ai.ml._schema.core.fields import DumpableFloatField, DumpableIntegerField
7+
from azure.ai.ml._schema import StringTransformedEnum, UnionField, PatchedSchemaMeta
8+
from azure.ai.ml._schema.component.input_output import InputPortSchema, ParameterSchema
9+
from azure.ai.ml._schema.core.fields import DumpableFloatField, DumpableIntegerField, DumpableEnumField
1010

1111

12+
SUPPORTED_INTERNAL_PARAM_TYPES = [
13+
"integer",
14+
"Integer",
15+
"boolean",
16+
"Boolean",
17+
"string",
18+
"String",
19+
"float",
20+
"Float",
21+
]
22+
1223
class InternalInputPortSchema(InputPortSchema):
1324
# skip client-side validate for type enum & support list
1425
type = UnionField(
@@ -29,29 +40,29 @@ def resolve_list_type(self, data, original_data, **kwargs): # pylint: disable=u
2940
return data
3041

3142

32-
class InternalOutputPortSchema(OutputPortSchema):
43+
class InternalOutputPortSchema(metaclass=PatchedSchemaMeta):
3344
# skip client-side validate for type enum
3445
type = fields.Str(
3546
required=True,
3647
data_key="type",
3748
)
49+
description = fields.Str()
3850
is_link_mode = fields.Bool()
3951
datastore_mode = fields.Str()
4052

4153

54+
class InternalPrimitiveOutputSchema(metaclass=PatchedSchemaMeta):
55+
type = DumpableEnumField(
56+
allowed_values=SUPPORTED_INTERNAL_PARAM_TYPES,
57+
required=True,
58+
)
59+
description = fields.Str()
60+
is_control = fields.Bool()
61+
62+
4263
class InternalParameterSchema(ParameterSchema):
43-
type = StringTransformedEnum(
44-
allowed_values=[
45-
"integer",
46-
"Integer",
47-
"boolean",
48-
"Boolean",
49-
"string",
50-
"String",
51-
"float",
52-
"Float",
53-
],
54-
casing_transform=lambda x: x,
64+
type = DumpableEnumField(
65+
allowed_values=SUPPORTED_INTERNAL_PARAM_TYPES,
5566
required=True,
5667
data_key="type",
5768
)

sdk/ml/azure-ai-ml/azure/ai/ml/_internal/_schema/node.py

+3
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313

1414
class InternalBaseNodeSchema(BaseNodeSchema):
15+
class Meta:
16+
unknown = INCLUDE
1517
component = UnionField(
1618
[
1719
# for registry type assets
@@ -56,6 +58,7 @@ class ScopeSchema(InternalBaseNodeSchema):
5658
class HDInsightSchema(InternalBaseNodeSchema):
5759
type = StringTransformedEnum(allowed_values=[NodeType.HDI], casing_transform=lambda x: x)
5860

61+
compute_name = fields.Str()
5962
queue = fields.Str()
6063
driver_memory = fields.Str()
6164
driver_cores = fields.Int()

sdk/ml/azure-ai-ml/azure/ai/ml/_internal/_util.py

+6
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
InternalComponent,
1818
Parallel,
1919
Scope,
20+
DataTransfer,
21+
Hemera,
22+
Starlite,
2023
)
2124
from azure.ai.ml._schema import NestedField
2225
from azure.ai.ml.entities._component.component_factory import component_factory
@@ -57,6 +60,9 @@ def enable_internal_components_in_pipeline():
5760
_register_node(_type, InternalBaseNode, InternalBaseNodeSchema)
5861

5962
# redo the registration for those with specific runsettings
63+
_register_node(NodeType.DATA_TRANSFER, DataTransfer, InternalBaseNodeSchema)
64+
_register_node(NodeType.HEMERA, Hemera, InternalBaseNodeSchema)
65+
_register_node(NodeType.STARLITE, Starlite, InternalBaseNodeSchema)
6066
_register_node(NodeType.COMMAND, Command, CommandSchema)
6167
_register_node(NodeType.DISTRIBUTED, Distributed, DistributedSchema)
6268
_register_node(NodeType.SCOPE, Scope, ScopeSchema)

sdk/ml/azure-ai-ml/azure/ai/ml/_internal/entities/_additional_includes.py

+21-6
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,34 @@
1616

1717
class _AdditionalIncludes:
1818
def __init__(self, code_path: Union[None, str], yaml_path: str):
19-
self._yaml_path = Path(yaml_path)
20-
self._yaml_name = self._yaml_path.name
21-
self._code_path = self._yaml_path.parent
22-
if code_path is not None:
23-
self._code_path = (self._code_path / code_path).resolve()
19+
self.__yaml_path = Path(yaml_path)
20+
self.__code_path = code_path
21+
2422
self._tmp_code_path = None
25-
self._additional_includes_file_path = self._yaml_path.with_suffix(f".{ADDITIONAL_INCLUDES_SUFFIX}")
2623
self._includes = None
2724
if self._additional_includes_file_path.is_file():
2825
with open(self._additional_includes_file_path, "r") as f:
2926
lines = f.readlines()
3027
self._includes = [line.strip() for line in lines if len(line.strip()) > 0]
3128

29+
@property
30+
def _yaml_path(self) -> Path:
31+
return self.__yaml_path
32+
33+
@property
34+
def _code_path(self) -> Path:
35+
if self.__code_path is not None:
36+
return (self._yaml_path.parent / self.__code_path).resolve()
37+
return self._yaml_path.parent
38+
39+
@property
40+
def _yaml_name(self) -> str:
41+
return self._yaml_path.name
42+
43+
@property
44+
def _additional_includes_file_path(self) -> Path:
45+
return self._yaml_path.with_suffix(f".{ADDITIONAL_INCLUDES_SUFFIX}")
46+
3247
@property
3348
def code(self) -> Path:
3449
return self._tmp_code_path if self._tmp_code_path else self._code_path

sdk/ml/azure-ai-ml/azure/ai/ml/_internal/entities/command.py

-4
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,6 @@ def compute(self) -> str:
4242
@compute.setter
4343
def compute(self, value: str):
4444
"""Set the compute definition for the command."""
45-
if value is not None and not isinstance(value, str):
46-
raise ValueError(f"Failed in setting compute: only string is supported in DPv2 but got {type(value)}")
4745
self._compute = value
4846

4947
@property
@@ -54,8 +52,6 @@ def environment(self) -> str:
5452
@environment.setter
5553
def environment(self, value: str):
5654
"""Set the environment definition for the command."""
57-
if value is not None and not isinstance(value, str):
58-
raise ValueError(f"Failed in setting environment: only string is supported in DPv2 but got {type(value)}")
5955
self._environment = value
6056

6157
@property

sdk/ml/azure-ai-ml/azure/ai/ml/_internal/entities/component.py

+16-11
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# disable redefined-builtin to use id/type as argument name
66
from contextlib import contextmanager
77
from typing import Dict, Union
8+
import os
89

910
from marshmallow import INCLUDE, Schema
1011

@@ -176,6 +177,21 @@ def _additional_includes(self):
176177
def _create_schema_for_validation(cls, context) -> Union[PathAwareSchema, Schema]:
177178
return InternalBaseComponentSchema(context=context)
178179

180+
def _validate(self, raise_error=False) -> MutableValidationResult:
181+
if self._additional_includes is not None and self._additional_includes._validate().passed:
182+
# update source path in case dependency file is in additional_includes
183+
with self._resolve_local_code() as tmp_base_path:
184+
origin_base_path, origin_source_path = self._base_path, self._source_path
185+
186+
try:
187+
self._base_path, self._source_path = \
188+
tmp_base_path, tmp_base_path / os.path.basename(self._source_path)
189+
return super()._validate(raise_error=raise_error)
190+
finally:
191+
self._base_path, self._source_path = origin_base_path, origin_source_path
192+
193+
return super()._validate(raise_error=raise_error)
194+
179195
def _customized_validate(self) -> MutableValidationResult:
180196
validation_result = super(InternalComponent, self)._customized_validate()
181197
if isinstance(self.environment, InternalEnvironment):
@@ -228,14 +244,3 @@ def _resolve_local_code(self):
228244

229245
def __call__(self, *args, **kwargs) -> InternalBaseNode: # pylint: disable=useless-super-delegation
230246
return super(InternalComponent, self).__call__(*args, **kwargs)
231-
232-
def _schema_validate(self) -> MutableValidationResult:
233-
"""Validate the resource with the schema.
234-
235-
return type: ValidationResult
236-
"""
237-
result = super(InternalComponent, self)._schema_validate()
238-
# skip unknown field warnings for internal components
239-
# TODO: move this logic into base class
240-
result._warnings = list(filter(lambda x: x.message != "Unknown field.", result._warnings))
241-
return result

sdk/ml/azure-ai-ml/azure/ai/ml/_internal/entities/node.py

+11-12
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from azure.ai.ml.entities._job.pipeline._io import NodeInput, NodeOutput, PipelineInput
1717
from azure.ai.ml.entities._util import convert_ordered_dict_to_dict
1818

19-
from ...entities._validation import MutableValidationResult
2019
from .._schema.component import NodeType
2120
from ._input_outputs import InternalInput
2221

@@ -95,17 +94,6 @@ def _to_job(self) -> Job:
9594
def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs) -> "Job":
9695
raise RuntimeError("Internal components doesn't support load from dict")
9796

98-
def _schema_validate(self) -> MutableValidationResult:
99-
"""Validate the resource with the schema.
100-
101-
return type: ValidationResult
102-
"""
103-
result = super(InternalBaseNode, self)._schema_validate()
104-
# skip unknown field warnings for internal components
105-
# TODO: move this logic into base class?
106-
result._warnings = list(filter(lambda x: x.message != "Unknown field.", result._warnings))
107-
return result
108-
10997
@classmethod
11098
def _create_schema_for_validation(cls, context) -> Union[PathAwareSchema, Schema]:
11199
from .._schema.node import InternalBaseNodeSchema
@@ -176,6 +164,7 @@ def __init__(self, **kwargs):
176164
kwargs.pop("type", None)
177165
super(HDInsight, self).__init__(type=NodeType.HDI, **kwargs)
178166
self._init = True
167+
self._compute_name: str = kwargs.pop("compute_name", None)
179168
self._queue: str = kwargs.pop("queue", None)
180169
self._driver_memory: str = kwargs.pop("driver_memory", None)
181170
self._driver_cores: int = kwargs.pop("driver_cores", None)
@@ -186,6 +175,15 @@ def __init__(self, **kwargs):
186175
self._hdinsight_spark_job_name: str = kwargs.pop("hdinsight_spark_job_name", None)
187176
self._init = False
188177

178+
@property
179+
def compute_name(self) -> str:
180+
"""Name of the compute to be used."""
181+
return self._compute_name
182+
183+
@compute_name.setter
184+
def compute_name(self, value: str):
185+
self._compute_name = value
186+
189187
@property
190188
def queue(self) -> str:
191189
"""The name of the YARN queue to which submitted."""
@@ -267,6 +265,7 @@ def hdinsight_spark_job_name(self, value: str):
267265
@classmethod
268266
def _picked_fields_from_dict_to_rest_object(cls) -> List[str]:
269267
return [
268+
"compute_name",
270269
"queue",
271270
"driver_cores",
272271
"executor_memory",

sdk/ml/azure-ai-ml/azure/ai/ml/_schema/component/command_component.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ class Meta:
3838
],
3939
metadata={"description": "Provides the configuration for a distributed run."},
4040
)
41-
# primitive output is only supported for command component
41+
# primitive output is only supported for command component & pipeline component
4242
outputs = fields.Dict(
4343
keys=fields.Str(),
4444
values=UnionField(

sdk/ml/azure-ai-ml/azure/ai/ml/_schema/component/input_output.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -51,16 +51,12 @@ class OutputPortSchema(metaclass=PatchedSchemaMeta):
5151
)
5252

5353

54-
class PrimitiveOutputSchema(metaclass=PatchedSchemaMeta):
54+
class PrimitiveOutputSchema(OutputPortSchema):
5555
type = DumpableEnumField(
5656
allowed_values=SUPPORTED_PARAM_TYPES,
5757
required=True,
5858
)
59-
description = fields.Str()
6059
is_control = fields.Bool()
61-
mode = DumpableEnumField(
62-
allowed_values=SUPPORTED_INPUT_OUTPUT_MODES,
63-
)
6460

6561

6662
class ParameterSchema(metaclass=PatchedSchemaMeta):

sdk/ml/azure-ai-ml/azure/ai/ml/_schema/core/schema_meta.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,10 @@ def __new__(cls, name, bases, dct):
4444
if meta is None:
4545
dct["Meta"] = PatchedMeta
4646
else:
47-
dct["Meta"].unknown = RAISE
48-
dct["Meta"].ordered = True
47+
if not hasattr(meta, "unknown"):
48+
dct["Meta"].unknown = RAISE
49+
if not hasattr(meta, "ordered"):
50+
dct["Meta"].ordered = True
4951

5052
bases = bases + (PatchedBaseSchema,)
5153
klass = super().__new__(cls, name, bases, dct)

sdk/ml/azure-ai-ml/azure/ai/ml/_schema/pipeline/pipeline_component.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ class PipelineComponentSchema(ComponentSchema):
136136
type = StringTransformedEnum(allowed_values=[NodeType.PIPELINE])
137137
jobs = PipelineJobsField()
138138

139-
# primitive output is only supported for command component
139+
# primitive output is only supported for command component & pipeline component
140140
outputs = fields.Dict(
141141
keys=fields.Str(),
142142
values=UnionField(

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_load_functions.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def load_common(
8989
return _load_common_raising_marshmallow_error(cls, yaml_dict, relative_origin, params_override, **kwargs)
9090
except ValidationError as e:
9191
if issubclass(cls, SchemaValidatableMixin):
92-
validation_result = _ValidationResultBuilder.from_validation_error(e, relative_origin)
92+
validation_result = _ValidationResultBuilder.from_validation_error(e, source_path=relative_origin)
9393
validation_result.try_raise(
9494
# pylint: disable=protected-access
9595
error_target=cls._get_validation_error_target(),
@@ -102,8 +102,7 @@ def load_common(
102102
f"of type {type_str}, please specify the correct "
103103
f"type in the 'type' property.",
104104
)
105-
else:
106-
raise e
105+
raise e
107106

108107

109108
def _try_load_yaml_dict(source: Union[str, PathLike, IO[AnyStr]]) -> dict:

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_validate_funcs.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def validate_common(cls, path, validate_func, params_override=None) -> Validatio
3030
except ValidationException as err:
3131
return _ValidationResultBuilder.from_single_message(err.message)
3232
except ValidationError as err:
33-
return _ValidationResultBuilder.from_validation_error(err, path)
33+
return _ValidationResultBuilder.from_validation_error(err, source_path=path)
3434

3535

3636
def validate_component(path, ml_client=None, params_override=None) -> ValidationResult:

0 commit comments

Comments
 (0)