|
1 | 1 | import pytest
|
| 2 | +from azure.ai.ml.dsl._group_decorator import group |
2 | 3 | from devtools_testutils import AzureRecordedTestCase, is_live
|
3 | 4 | from test_utilities.utils import _PYTEST_TIMEOUT_METHOD, assert_job_cancel, omit_with_wildcard
|
4 | 5 |
|
5 |
| -from azure.ai.ml import Input, MLClient, load_component |
| 6 | +from azure.ai.ml import Input, MLClient, load_component, Output |
6 | 7 | from azure.ai.ml.dsl import pipeline
|
7 | 8 | from azure.ai.ml.dsl._condition import condition
|
8 | 9 | from azure.ai.ml.dsl._do_while import do_while
|
@@ -171,12 +172,36 @@ def test_registered_component_is_control(self, client: MLClient):
|
171 | 172 | registered_component = client.components.create_or_update(primitive_component_with_normal_input_output_v2)
|
172 | 173 | rest_dict = registered_component._to_dict()
|
173 | 174 | # Assert is_control with correct bool type
|
174 |
| - assert rest_dict["outputs"] == { |
| 175 | + expected_dict = { |
175 | 176 | "output_data": {"type": "uri_folder"},
|
176 | 177 | "bool_param_output": {"type": "boolean", "is_control": True, "early_available": True},
|
177 | 178 | "int_param_output": {"type": "integer", "is_control": True},
|
178 | 179 | "float_param_output": {"type": "number", "is_control": True},
|
179 | 180 | "str_param_output": {"type": "string", "is_control": True}}
|
| 181 | + assert rest_dict["outputs"] == expected_dict |
| 182 | + |
| 183 | + # Assert on pipeline component |
| 184 | + @group |
| 185 | + class ControlOutputGroup: |
| 186 | + output_data: Output(type="uri_folder") |
| 187 | + float_param_output: Output(type="number", is_control=True) |
| 188 | + int_param_output: Output(type="integer", is_control=True) |
| 189 | + bool_param_output: Output(type="boolean", is_control=True) |
| 190 | + str_param_output: Output(type="string", is_control=True) |
| 191 | + |
| 192 | + @pipeline() |
| 193 | + def test_pipeline_component_control_output() -> ControlOutputGroup: |
| 194 | + node = primitive_component_with_normal_input_output_v2( |
| 195 | + input_data=test_input, parambool=True, |
| 196 | + paramint=2, paramfloat=2.2, paramstr="test" |
| 197 | + ) |
| 198 | + return node.outputs |
| 199 | + |
| 200 | + registered_pipeline_component = client.components.create_or_update(test_pipeline_component_control_output) |
| 201 | + rest_dict = registered_pipeline_component._to_dict() |
| 202 | + # Update expected dict, early_available will be removed for subgraph output. |
| 203 | + expected_dict["bool_param_output"] = {"type": "boolean", "is_control": True} |
| 204 | + assert rest_dict["outputs"] == expected_dict |
180 | 205 |
|
181 | 206 | def test_do_while_combined_if_else(self, client: MLClient):
|
182 | 207 | do_while_body_component = load_component(
|
|
0 commit comments