Skip to content

Commit a17fed2

Browse files
authored
[ML][Pipelines]Fix if_else CLI error and add related tests (Azure#28252)
* add test for if_else cli * add tests
1 parent 83fca5e commit a17fed2

18 files changed

+3478
-473
lines changed

sdk/ml/azure-ai-ml/azure/ai/ml/_schema/pipeline/pipeline_component.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def _post_load_pipeline_jobs(context, data: dict) -> dict:
9393
from azure.ai.ml.entities._builders.parallel_for import ParallelFor
9494
from azure.ai.ml.entities._job.automl.automl_job import AutoMLJob
9595
from azure.ai.ml.entities._job.pipeline._component_translatable import ComponentTranslatableMixin
96+
from azure.ai.ml.entities._builders.condition_node import ConditionNode
9697

9798
# parse inputs/outputs
9899
data = parse_inputs_outputs(data)
@@ -107,6 +108,10 @@ def _post_load_pipeline_jobs(context, data: dict) -> dict:
107108
loaded_data=job_instance,
108109
)
109110
jobs[key] = job_instance
111+
elif job_instance.get("type") == ControlFlowType.IF_ELSE:
112+
# Convert to if-else node.
113+
job_instance = ConditionNode._create_instance_from_schema_dict(loaded_data=job_instance)
114+
jobs[key] = job_instance
110115
elif job_instance.get("type") == ControlFlowType.DO_WHILE:
111116
# Convert to do-while node.
112117
job_instance = DoWhile._create_instance_from_schema_dict(pipeline_jobs=jobs, loaded_data=job_instance)

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_builders/condition_node.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import Dict
66

77
from azure.ai.ml._schema import PathAwareSchema
8+
from azure.ai.ml._utils.utils import is_data_binding_expression
89
from azure.ai.ml.constants._component import ControlFlowType
910
from azure.ai.ml.entities._builders import BaseNode
1011
from azure.ai.ml.entities._builders.control_flow_node import ControlFlowNode
@@ -35,6 +36,11 @@ def _create_schema_for_validation(cls, context) -> PathAwareSchema: # pylint: d
3536
def _from_rest_object(cls, obj: dict) -> "ConditionNode":
3637
return cls(**obj)
3738

39+
@classmethod
40+
def _create_instance_from_schema_dict(cls, loaded_data: Dict) -> "ConditionNode":
41+
"""Create a condition node instance from schema parsed dict."""
42+
return cls(**loaded_data)
43+
3844
def _to_dict(self) -> Dict:
3945
return self._dump_for_validation()
4046

@@ -63,16 +69,37 @@ def _validate_params(self, raise_error=True) -> MutableValidationResult:
6369
f"with value 'True', got {output_definition.is_control}",
6470
)
6571

66-
error_msg = "{!r} of dsl.condition node must be an instance of " f"{BaseNode} or {AutoMLJob}," "got {!r}."
67-
if self.true_block is not None and not isinstance(self.true_block, (BaseNode, AutoMLJob)):
72+
# check if condition is valid binding
73+
if isinstance(self.condition, str) and not is_data_binding_expression(
74+
self.condition, ["parent"], is_singular=False):
75+
error_tail = "for example, ${{parent.jobs.xxx.outputs.output}}"
76+
validation_result.append_error(
77+
yaml_path="condition",
78+
message=f"'condition' of dsl.condition has invalid binding expression: {self.condition}, {error_tail}",
79+
)
80+
81+
error_msg = "{!r} of dsl.condition node must be an instance of " \
82+
f"{BaseNode}, {AutoMLJob} or {str}," "got {!r}."
83+
if self.true_block is not None and not isinstance(self.true_block, (BaseNode, AutoMLJob, str)):
6884
validation_result.append_error(
6985
yaml_path="true_block", message=error_msg.format("true_block", type(self.true_block))
7086
)
71-
if self.false_block is not None and not isinstance(self.false_block, (BaseNode, AutoMLJob)):
87+
if self.false_block is not None and not isinstance(self.false_block, (BaseNode, AutoMLJob, str)):
7288
validation_result.append_error(
7389
yaml_path="false_block", message=error_msg.format("false_block", type(self.false_block))
7490
)
7591

92+
# check if true/false block is valid binding
93+
for name, block in {"true_block": self.true_block, "false_block": self.false_block}.items():
94+
if block is None or not isinstance(block, str):
95+
continue
96+
error_tail = "for example, ${{parent.jobs.xxx}}"
97+
if not is_data_binding_expression(block, ["parent", "jobs"], is_singular=False):
98+
validation_result.append_error(
99+
yaml_path=name,
100+
message=f"'{name}' of dsl.condition has invalid binding expression: {block}, {error_tail}",
101+
)
102+
76103
if self.true_block is None and self.false_block is None:
77104
validation_result.append_error(
78105
yaml_path="true_block",

sdk/ml/azure-ai-ml/tests/dsl/e2etests/test_controlflow_pipeline.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def test_dsl_condition_pipeline(self, client: MLClient):
5555
compute="cpu-cluster",
5656
)
5757
def condition_pipeline():
58-
result = basic_component(str_param="abc", int_param=1)
58+
result = basic_component()
5959

6060
node1 = hello_world_component_no_paths(component_in_number=1)
6161
node2 = hello_world_component_no_paths(component_in_number=2)
@@ -89,10 +89,6 @@ def condition_pipeline():
8989
},
9090
"result": {
9191
"_source": "REMOTE.WORKSPACE.COMPONENT",
92-
"inputs": {
93-
"int_param": {"job_input_type": "literal", "value": "1"},
94-
"str_param": {"job_input_type": "literal", "value": "abc"},
95-
},
9692
"name": "result",
9793
"type": "command",
9894
},

sdk/ml/azure-ai-ml/tests/pipeline_job/e2etests/test_control_flow_pipeline.py

Lines changed: 91 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,28 @@
11
from typing import Callable
22

33
import pytest
4+
5+
from azure.ai.ml.exceptions import ValidationException
46
from devtools_testutils import AzureRecordedTestCase, is_live
57
from test_utilities.utils import _PYTEST_TIMEOUT_METHOD
68

79
from azure.ai.ml import MLClient, load_job
810
from azure.ai.ml._schema.pipeline import pipeline_job
9-
from azure.ai.ml._utils.utils import load_yaml
1011
from azure.ai.ml.entities._builders import Command, Pipeline
1112
from azure.ai.ml.entities._builders.do_while import DoWhile
1213
from azure.ai.ml.entities._builders.parallel_for import ParallelFor
1314

1415
from .._util import _PIPELINE_JOB_TIMEOUT_SECOND
1516
from .test_pipeline_job import assert_job_cancel
17+
from test_utilities.utils import omit_with_wildcard
18+
19+
omit_fields = [
20+
"name",
21+
"properties.display_name",
22+
"properties.settings",
23+
"properties.jobs.*._source",
24+
"properties.jobs.*.componentId",
25+
]
1626

1727

1828
@pytest.fixture()
@@ -44,6 +54,86 @@ class TestConditionalNodeInPipeline(AzureRecordedTestCase):
4454
pass
4555

4656

57+
class TestIfElse(TestConditionalNodeInPipeline):
58+
def test_happy_path_if_else(self, client: MLClient, randstr: Callable[[], str]) -> None:
59+
params_override = [{"name": randstr('name')}]
60+
my_job = load_job(
61+
"./tests/test_configs/pipeline_jobs/control_flow/if_else/simple_pipeline.yml",
62+
params_override=params_override,
63+
)
64+
created_pipeline = assert_job_cancel(my_job, client)
65+
66+
pipeline_job_dict = created_pipeline._to_rest_object().as_dict()
67+
68+
pipeline_job_dict = omit_with_wildcard(pipeline_job_dict, *omit_fields)
69+
assert pipeline_job_dict["properties"]["jobs"] == {
70+
'conditionnode': {'condition': '${{parent.jobs.result.outputs.output}}',
71+
'false_block': '${{parent.jobs.node1}}',
72+
'true_block': '${{parent.jobs.node2}}',
73+
'type': 'if_else'},
74+
'node1': {'inputs': {'component_in_number': {'job_input_type': 'literal',
75+
'value': '1'}},
76+
'name': 'node1',
77+
'type': 'command'},
78+
'node2': {'inputs': {'component_in_number': {'job_input_type': 'literal',
79+
'value': '2'}},
80+
'name': 'node2',
81+
'type': 'command'},
82+
'result': {'name': 'result', 'type': 'command'}
83+
}
84+
85+
def test_if_else_one_branch(self, client: MLClient, randstr: Callable[[], str]) -> None:
86+
params_override = [{"name": randstr('name')}]
87+
my_job = load_job(
88+
"./tests/test_configs/pipeline_jobs/control_flow/if_else/one_branch.yml",
89+
params_override=params_override,
90+
)
91+
created_pipeline = assert_job_cancel(my_job, client)
92+
93+
pipeline_job_dict = created_pipeline._to_rest_object().as_dict()
94+
95+
pipeline_job_dict = omit_with_wildcard(pipeline_job_dict, *omit_fields)
96+
assert pipeline_job_dict["properties"]["jobs"] == {
97+
'conditionnode': {'condition': '${{parent.jobs.result.outputs.output}}',
98+
'true_block': '${{parent.jobs.node1}}',
99+
'type': 'if_else'},
100+
'node1': {'inputs': {'component_in_number': {'job_input_type': 'literal',
101+
'value': '1'}},
102+
'name': 'node1',
103+
'type': 'command'},
104+
'result': {'name': 'result', 'type': 'command'}
105+
}
106+
107+
def test_if_else_literal_condition(self, client: MLClient, randstr: Callable[[], str]) -> None:
108+
params_override = [{"name": randstr('name')}]
109+
my_job = load_job(
110+
"./tests/test_configs/pipeline_jobs/control_flow/if_else/literal_condition.yml",
111+
params_override=params_override,
112+
)
113+
created_pipeline = assert_job_cancel(my_job, client)
114+
115+
pipeline_job_dict = created_pipeline._to_rest_object().as_dict()
116+
117+
pipeline_job_dict = omit_with_wildcard(pipeline_job_dict, *omit_fields)
118+
assert pipeline_job_dict["properties"]["jobs"] == {
119+
'conditionnode': {'condition': True,
120+
'true_block': '${{parent.jobs.node1}}',
121+
'type': 'if_else'},
122+
'node1': {'inputs': {'component_in_number': {'job_input_type': 'literal',
123+
'value': '1'}},
124+
'name': 'node1',
125+
'type': 'command'}
126+
}
127+
128+
def test_if_else_invalid_case(self, client: MLClient, randstr: Callable[[], str]) -> None:
129+
my_job = load_job(
130+
"./tests/test_configs/pipeline_jobs/control_flow/if_else/invalid_binding.yml",
131+
)
132+
with pytest.raises(ValidationException) as e:
133+
my_job._validate(raise_error=True)
134+
assert '"path": "jobs.conditionnode.true_block",' in str(e.value)
135+
assert "'true_block' of dsl.condition has invalid binding expression:" in str(e.value)
136+
47137
class TestDoWhile(TestConditionalNodeInPipeline):
48138
def test_pipeline_with_do_while_node(self, client: MLClient, randstr: Callable[[], str]) -> None:
49139
params_override = [{"name": randstr('name')}]

0 commit comments

Comments
 (0)