Skip to content

Commit ce859ba

Browse files
committed
fix: support is_control output for internal components
1 parent 799f975 commit ce859ba

File tree

10 files changed

+85
-23
lines changed

10 files changed

+85
-23
lines changed

sdk/ml/azure-ai-ml/azure/ai/ml/_internal/_schema/component.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
InternalInputPortSchema,
1616
InternalOutputPortSchema,
1717
InternalParameterSchema,
18+
InternalPrimitiveOutputSchema,
1819
)
1920

2021

@@ -60,7 +61,16 @@ class InternalBaseComponentSchema(ComponentSchema):
6061
]
6162
),
6263
)
63-
outputs = fields.Dict(keys=fields.Str(), values=NestedField(InternalOutputPortSchema))
64+
# support primitive output for all internal components for now
65+
outputs = fields.Dict(
66+
keys=fields.Str(),
67+
values=UnionField(
68+
[
69+
NestedField(InternalPrimitiveOutputSchema),
70+
NestedField(InternalOutputPortSchema),
71+
]
72+
),
73+
)
6474

6575
# type field is required for registration
6676
type = StringTransformedEnum(

sdk/ml/azure-ai-ml/azure/ai/ml/_internal/_schema/input_output.py

+22-13
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,20 @@
66

77
from azure.ai.ml._schema import StringTransformedEnum, UnionField
88
from azure.ai.ml._schema.component.input_output import InputPortSchema, OutputPortSchema, ParameterSchema
9-
from azure.ai.ml._schema.core.fields import DumpableFloatField, DumpableIntegerField
9+
from azure.ai.ml._schema.core.fields import DumpableFloatField, DumpableIntegerField, DumpableEnumField
1010

1111

12+
SUPPORTED_INTERNAL_PARAM_TYPES = [
13+
"integer",
14+
"Integer",
15+
"boolean",
16+
"Boolean",
17+
"string",
18+
"String",
19+
"float",
20+
"Float",
21+
]
22+
1223
class InternalInputPortSchema(InputPortSchema):
1324
# skip client-side validate for type enum & support list
1425
type = UnionField(
@@ -39,19 +50,17 @@ class InternalOutputPortSchema(OutputPortSchema):
3950
datastore_mode = fields.Str()
4051

4152

53+
class InternalPrimitiveOutputSchema(OutputPortSchema):
54+
type = DumpableEnumField(
55+
allowed_values=SUPPORTED_INTERNAL_PARAM_TYPES,
56+
required=True,
57+
)
58+
is_control = fields.Bool()
59+
60+
4261
class InternalParameterSchema(ParameterSchema):
43-
type = StringTransformedEnum(
44-
allowed_values=[
45-
"integer",
46-
"Integer",
47-
"boolean",
48-
"Boolean",
49-
"string",
50-
"String",
51-
"float",
52-
"Float",
53-
],
54-
casing_transform=lambda x: x,
62+
type = DumpableEnumField(
63+
allowed_values=SUPPORTED_INTERNAL_PARAM_TYPES,
5564
required=True,
5665
data_key="type",
5766
)

sdk/ml/azure-ai-ml/azure/ai/ml/_schema/component/command_component.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ class Meta:
3838
],
3939
metadata={"description": "Provides the configuration for a distributed run."},
4040
)
41-
# primitive output is only supported for command component
41+
# primitive output is only supported for command component & pipeline component
4242
outputs = fields.Dict(
4343
keys=fields.Str(),
4444
values=UnionField(

sdk/ml/azure-ai-ml/azure/ai/ml/_schema/component/input_output.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -51,16 +51,12 @@ class OutputPortSchema(metaclass=PatchedSchemaMeta):
5151
)
5252

5353

54-
class PrimitiveOutputSchema(metaclass=PatchedSchemaMeta):
54+
class PrimitiveOutputSchema(OutputPortSchema):
5555
type = DumpableEnumField(
5656
allowed_values=SUPPORTED_PARAM_TYPES,
5757
required=True,
5858
)
59-
description = fields.Str()
6059
is_control = fields.Bool()
61-
mode = DumpableEnumField(
62-
allowed_values=SUPPORTED_INPUT_OUTPUT_MODES,
63-
)
6460

6561

6662
class ParameterSchema(metaclass=PatchedSchemaMeta):

sdk/ml/azure-ai-ml/azure/ai/ml/_schema/pipeline/pipeline_component.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ class PipelineComponentSchema(ComponentSchema):
136136
type = StringTransformedEnum(allowed_values=[NodeType.PIPELINE])
137137
jobs = PipelineJobsField()
138138

139-
# primitive output is only supported for command component
139+
# primitive output is only supported for command component & pipeline component
140140
outputs = fields.Dict(
141141
keys=fields.Str(),
142142
values=UnionField(

sdk/ml/azure-ai-ml/tests/internal/unittests/test_component.py

+20-2
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,7 @@ def test_invalid_additional_includes(self, yaml_path: str, expected_error_msg_pr
377377
assert validation_result.error_messages["*"].startswith(expected_error_msg_prefix)
378378

379379
def test_component_input_types(self) -> None:
380-
yaml_path = "./tests/test_configs/internal/component_with_input_types/component_spec.yaml"
380+
yaml_path = "./tests/test_configs/internal/component_with_input_outputs/component_spec.yaml"
381381
component: InternalComponent = load_component(yaml_path)
382382
component.code = "scope:1"
383383

@@ -403,7 +403,7 @@ def test_component_input_types(self) -> None:
403403
assert component._validate().passed is True, repr(component._validate())
404404

405405
def test_component_input_with_attrs(self) -> None:
406-
yaml_path = "./tests/test_configs/internal/component_with_input_types/component_spec_with_attrs.yaml"
406+
yaml_path = "./tests/test_configs/internal/component_with_input_outputs/component_spec_with_attrs.yaml"
407407
component: InternalComponent = load_component(source=yaml_path)
408408

409409
expected_inputs = {
@@ -433,6 +433,24 @@ def test_component_input_with_attrs(self) -> None:
433433
assert regenerated_component._to_rest_object().properties.component_spec["inputs"] == expected_inputs["inputs"]
434434
assert component._validate().passed is True, repr(component._validate())
435435

436+
def test_component_output_with_attrs(self) -> None:
437+
yaml_path = "./tests/test_configs/internal/component_with_input_outputs/component_spec_with_outputs.yaml"
438+
component: InternalComponent = load_component(source=yaml_path)
439+
assert component
440+
441+
expected_outputs = {
442+
"primitive_is_control": {
443+
"is_control": True,
444+
"type": "boolean",
445+
}
446+
}
447+
assert component._to_rest_object().properties.component_spec["outputs"] == expected_outputs
448+
assert component._validate().passed is True, repr(component._validate())
449+
450+
regenerated_component = Component._from_rest_object(component._to_rest_object())
451+
assert regenerated_component._to_rest_object().properties.component_spec["outputs"] == expected_outputs
452+
assert component._validate().passed is True, repr(component._validate())
453+
436454
def test_component_input_list_type(self) -> None:
437455
yaml_path = "./tests/test_configs/internal/scope-component/component_spec.yaml"
438456
component: InternalComponent = load_component(yaml_path)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
$schema: https://componentsdk.azureedge.net/jsonschema/CommandComponent.json
2+
name: convert_file_to_boolean
3+
version: 0.0.1
4+
display_name: Convert File to Boolean
5+
type: CommandComponent
6+
description: Convert input file to boolean type output.
7+
tags:
8+
codegenBy: dsl.component
9+
inputs:
10+
input:
11+
optional: false
12+
type:
13+
- AnyFile
14+
- AnyDirectory
15+
outputs:
16+
primitive_is_control:
17+
is_control: true
18+
type: boolean
19+
command: python -m azure.ml.component.dsl.executor --file run.py --name convert_file_to_boolean
20+
--params --input {inputs.input} --output {outputs.output}
21+
environment:
22+
name: AzureML-Component
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
#DECLARE Output_stream string = @@Output_SSPath@@;
2+
#DECLARE In_Data string =@"@@Input_TextData@@";
3+
4+
RawData = EXTRACT @@ExtractionClause@@ FROM @In_Data
5+
USING DefaultTextExtractor();
6+
7+
OUTPUT RawData TO SSTREAM @Output_stream;

0 commit comments

Comments
 (0)