Skip to content

Commit 92f5938

Browse files
authored
[ML][Pipeline] Fix primitive pipeline component output (Azure#28808)
* Support primitive output on sub graph

  Signed-off-by: Brynn Yin <[email protected]>

* Add test

  Signed-off-by: Brynn Yin <[email protected]>

---------

Signed-off-by: Brynn Yin <[email protected]>
1 parent 93c72bd commit 92f5938

File tree

5 files changed

+554
-92
lines changed

5 files changed

+554
-92
lines changed

sdk/ml/azure-ai-ml/azure/ai/ml/_internal/_schema/input_output.py

+2
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
"String",
1818
"float",
1919
"Float",
20+
"double",
21+
"Double"
2022
]
2123

2224

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# ---------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# ---------------------------------------------------------
4+
5+
# hack: map internal component output type to valid v2 output type
6+
from azure.ai.ml._internal._schema.input_output import SUPPORTED_INTERNAL_PARAM_TYPES
7+
from azure.ai.ml._utils.utils import get_all_enum_values_iter
8+
from azure.ai.ml.constants import AssetTypes
9+
from azure.ai.ml.constants._common import InputTypes
10+
11+
12+
def _map_internal_output_type(_meta):
    """Translate a component output's declared type into a pipeline output type.

    Resolution order:

    1. An output whose class is not named ``InternalOutput`` keeps its type.
    2. An internal output whose type is already one of the ``AssetTypes``
       values keeps its type.
    3. A type listed in ``SUPPORTED_INTERNAL_PARAM_TYPES`` is lower-cased;
       ``double`` and ``float`` are reported as ``InputTypes.NUMBER``.
    4. ``AnyFile`` becomes ``AssetTypes.URI_FILE``.
    5. Everything else (``AnyDirectory`` included) becomes
       ``AssetTypes.URI_FOLDER``.

    :param _meta: Output definition exposing a ``type`` attribute.
    :return: The type to declare on the pipeline output.
    """
    # The output class is recognised by its class name, not via isinstance.
    if type(_meta).__name__ != "InternalOutput":
        return _meta.type

    declared = _meta.type

    known_asset_types = tuple(get_all_enum_values_iter(AssetTypes))
    if declared in known_asset_types:
        return declared

    if declared in SUPPORTED_INTERNAL_PARAM_TYPES:
        # Primitive names come in several casings; normalise before mapping.
        lowered = declared.lower()
        return InputTypes.NUMBER if lowered in ("double", "float") else lowered

    if declared == "AnyFile":
        return AssetTypes.URI_FILE

    # AnyDirectory and any remaining type fall back to a folder.
    return AssetTypes.URI_FOLDER

sdk/ml/azure-ai-ml/azure/ai/ml/dsl/_pipeline_component_builder.py

+3-13
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,12 @@
1111
from inspect import Parameter, signature
1212
from typing import Callable, Union
1313

14+
from azure.ai.ml._internal._utils._utils import _map_internal_output_type
1415
from azure.ai.ml._utils._func_utils import get_outputs_and_locals
1516
from azure.ai.ml._utils.utils import (
16-
get_all_enum_values_iter,
1717
is_valid_node_name,
1818
parse_args_description_from_docstring,
1919
)
20-
from azure.ai.ml.constants import AssetTypes
2120
from azure.ai.ml.constants._component import ComponentSource, IOConstants
2221
from azure.ai.ml.constants._job.pipeline import COMPONENT_IO_KEYWORDS
2322
from azure.ai.ml.dsl._utils import _sanitize_python_variable_name
@@ -259,19 +258,10 @@ def _build_pipeline_outputs(self, outputs: typing.Dict[str, NodeOutput]):
259258
is_control=value.is_control,
260259
)
261260

262-
# hack: map component output type to valid pipeline output type
263-
def _map_type(_meta):
264-
if type(_meta).__name__ != "InternalOutput":
265-
return _meta.type
266-
if _meta.type in list(get_all_enum_values_iter(AssetTypes)):
267-
return _meta.type
268-
if _meta.type in ["AnyFile"]:
269-
return AssetTypes.URI_FILE
270-
return AssetTypes.URI_FOLDER
271-
272261
# Note: Here we set PipelineOutput as Pipeline's output definition as we need output binding.
273262
output_meta = Output(
274-
type=_map_type(meta), description=meta.description, mode=meta.mode, is_control=meta.is_control
263+
type=_map_internal_output_type(meta), description=meta.description,
264+
mode=meta.mode, is_control=meta.is_control
275265
)
276266
pipeline_output = PipelineOutput(
277267
port_name=key,

sdk/ml/azure-ai-ml/tests/dsl/e2etests/test_controlflow_pipeline.py

+27-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import pytest
2+
from azure.ai.ml.dsl._group_decorator import group
23
from devtools_testutils import AzureRecordedTestCase, is_live
34
from test_utilities.utils import _PYTEST_TIMEOUT_METHOD, assert_job_cancel, omit_with_wildcard
45

5-
from azure.ai.ml import Input, MLClient, load_component
6+
from azure.ai.ml import Input, MLClient, load_component, Output
67
from azure.ai.ml.dsl import pipeline
78
from azure.ai.ml.dsl._condition import condition
89
from azure.ai.ml.dsl._do_while import do_while
@@ -171,12 +172,36 @@ def test_registered_component_is_control(self, client: MLClient):
171172
registered_component = client.components.create_or_update(primitive_component_with_normal_input_output_v2)
172173
rest_dict = registered_component._to_dict()
173174
# Assert is_control with correct bool type
174-
assert rest_dict["outputs"] == {
175+
expected_dict = {
175176
"output_data": {"type": "uri_folder"},
176177
"bool_param_output": {"type": "boolean", "is_control": True, "early_available": True},
177178
"int_param_output": {"type": "integer", "is_control": True},
178179
"float_param_output": {"type": "number", "is_control": True},
179180
"str_param_output": {"type": "string", "is_control": True}}
181+
assert rest_dict["outputs"] == expected_dict
182+
183+
# Assert on pipeline component
184+
@group
185+
class ControlOutputGroup:
186+
output_data: Output(type="uri_folder")
187+
float_param_output: Output(type="number", is_control=True)
188+
int_param_output: Output(type="integer", is_control=True)
189+
bool_param_output: Output(type="boolean", is_control=True)
190+
str_param_output: Output(type="string", is_control=True)
191+
192+
@pipeline()
193+
def test_pipeline_component_control_output() -> ControlOutputGroup:
194+
node = primitive_component_with_normal_input_output_v2(
195+
input_data=test_input, parambool=True,
196+
paramint=2, paramfloat=2.2, paramstr="test"
197+
)
198+
return node.outputs
199+
200+
registered_pipeline_component = client.components.create_or_update(test_pipeline_component_control_output)
201+
rest_dict = registered_pipeline_component._to_dict()
202+
# Update expected dict, early_available will be removed for subgraph output.
203+
expected_dict["bool_param_output"] = {"type": "boolean", "is_control": True}
204+
assert rest_dict["outputs"] == expected_dict
180205

181206
def test_do_while_combined_if_else(self, client: MLClient):
182207
do_while_body_component = load_component(

0 commit comments

Comments
 (0)