Skip to content

Commit 86684fb

Browse files
authored
Support identity in pipeline node (#26975)
* support identity in node * add test * update * fix pylint error * fix test
1 parent fa63b99 commit 86684fb

13 files changed

+1498
-333
lines changed

sdk/ml/azure-ai-ml/azure/ai/ml/_schema/pipeline/component_job.py

+7
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,13 @@ class CommandSchema(BaseNodeSchema, ParameterizedCommandSchema):
144144
],
145145
)
146146
services = fields.Dict(keys=fields.Str(), values=NestedField(JobServiceSchema))
147+
identity = UnionField(
148+
[
149+
NestedField(ManagedIdentitySchema),
150+
NestedField(AMLTokenIdentitySchema),
151+
NestedField(UserIdentitySchema),
152+
]
153+
)
147154

148155
@post_load
149156
def make(self, data, **kwargs) -> "Command":

sdk/ml/azure-ai-ml/azure/ai/ml/constants/_common.py

+6
Original file line numberDiff line numberDiff line change
@@ -571,3 +571,9 @@ class RollingRate:
571571
class Scope:
572572
SUBSCRIPTION="subscription"
573573
RESOURCE_GROUP="resource_group"
574+
575+
576+
class IdentityType:
577+
AML_TOKEN = "aml_token"
578+
USER_IDENTITY = "user_identity"
579+
MANAGED_IDENTITY = "managed_identity"

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_builders/command.py

+5
Original file line numberDiff line numberDiff line change
@@ -474,6 +474,7 @@ def _to_rest_object(self, **kwargs) -> dict:
474474
"limits": get_rest_dict_for_node_attrs(self.limits, clear_empty_value=True),
475475
"resources": get_rest_dict_for_node_attrs(self.resources, clear_empty_value=True),
476476
"services": get_rest_dict_for_node_attrs(self.services),
477+
"identity": self.identity._to_dict() if self.identity else None,
477478
}.items():
478479
if value is not None:
479480
rest_obj[key] = value
@@ -529,6 +530,9 @@ def _from_rest_object(cls, obj: dict) -> "Command":
529530
rest_limits = RestCommandJobLimits.from_dict(obj["limits"])
530531
obj["limits"] = CommandJobLimits()._from_rest_object(rest_limits)
531532

533+
if "identity" in obj and obj["identity"]:
534+
obj["identity"] = _BaseJobIdentityConfiguration._load(obj["identity"])
535+
532536
return Command(**obj)
533537

534538
@classmethod
@@ -621,6 +625,7 @@ def __call__(self, *args, **kwargs) -> "Command":
621625
node.distribution = copy.deepcopy(self.distribution)
622626
node.resources = copy.deepcopy(self.resources)
623627
node.services = copy.deepcopy(self.services)
628+
node.identity = copy.deepcopy(self.identity)
624629
return node
625630
msg = "Command can be called as a function only when referenced component is {}, currently got {}."
626631
raise ValidationException(

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_credentials.py

+70-7
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@
55
# pylint: disable=protected-access,redefined-builtin
66

77
from abc import ABC
8-
from typing import List
8+
from typing import List, Dict, Union
99

1010
from azure.ai.ml._azure_environments import _get_active_directory_url_from_metadata
1111
from azure.ai.ml._utils.utils import camel_to_snake, snake_to_pascal
12-
from azure.ai.ml.entities._mixins import RestTranslatableMixin, DictMixin
12+
from azure.ai.ml.constants._common import CommonYamlFields, IdentityType
13+
from azure.ai.ml.entities._mixins import RestTranslatableMixin, DictMixin, YamlTranslatableMixin
1314
from azure.ai.ml._restclient.v2022_05_01.models import (
1415
AccountKeyDatastoreCredentials as RestAccountKeyDatastoreCredentials,
1516
AccountKeyDatastoreSecrets as RestAccountKeyDatastoreSecrets,
@@ -46,7 +47,7 @@
4647

4748
from azure.ai.ml._restclient.v2022_10_01_preview.models import IdentityConfiguration as RestJobIdentityConfiguration
4849

49-
from azure.ai.ml.exceptions import ErrorTarget, ErrorCategory, JobException
50+
from azure.ai.ml.exceptions import ErrorTarget, ErrorCategory, JobException, ValidationErrorType, ValidationException
5051

5152
from azure.ai.ml._restclient.v2022_05_01.models import (
5253
ManagedServiceIdentity as RestManagedServiceIdentityConfiguration,
@@ -318,7 +319,7 @@ def __ne__(self, other: object) -> bool:
318319
return not self.__eq__(other)
319320

320321

321-
class _BaseJobIdentityConfiguration(ABC, RestTranslatableMixin, DictMixin):
322+
class _BaseJobIdentityConfiguration(ABC, RestTranslatableMixin, DictMixin, YamlTranslatableMixin):
322323
def __init__(self):
323324
self.type = None
324325

@@ -342,6 +343,29 @@ def _from_rest_object(cls, obj: RestJobIdentityConfiguration) -> "Identity":
342343
error_category=ErrorCategory.SYSTEM_ERROR,
343344
)
344345

346+
@classmethod
347+
def _load(
348+
cls,
349+
data: Dict = None,
350+
) -> Union["ManagedIdentityConfiguration", "UserIdentityConfiguration", "AmlTokenConfiguration"]:
351+
type_str = data.get(CommonYamlFields.TYPE)
352+
if type_str == IdentityType.MANAGED_IDENTITY:
353+
identity_cls = ManagedIdentityConfiguration
354+
elif type_str == IdentityType.USER_IDENTITY:
355+
identity_cls = UserIdentityConfiguration
356+
elif type_str == IdentityType.AML_TOKEN:
357+
identity_cls = AmlTokenConfiguration
358+
else:
359+
msg = f"Unsupported identity type: {type_str}."
360+
raise ValidationException(
361+
message=msg,
362+
no_personal_data_message=msg,
363+
target=ErrorTarget.IDENTITY,
364+
error_category=ErrorCategory.USER_ERROR,
365+
error_type=ValidationErrorType.INVALID_VALUE,
366+
)
367+
return identity_cls._load_from_dict(data)
368+
345369

346370
class ManagedIdentityConfiguration(_BaseIdentityConfiguration):
347371
"""Managed Identity Credentials.
@@ -356,7 +380,7 @@ def __init__(
356380
self, *, client_id: str = None, resource_id: str = None, object_id: str = None, principal_id: str = None
357381
):
358382
super().__init__()
359-
self.type = camel_to_snake(ConnectionAuthType.MANAGED_IDENTITY)
383+
self.type = IdentityType.MANAGED_IDENTITY
360384
self.client_id = client_id
361385
# TODO: Check if both client_id and resource_id are required
362386
self.resource_id = resource_id
@@ -418,6 +442,19 @@ def _from_workspace_rest_object(cls, obj: RestUserAssignedIdentityConfiguration)
418442
client_id=obj.client_id,
419443
)
420444

445+
def _to_dict(self) -> Dict:
446+
# pylint: disable=no-member
447+
from azure.ai.ml._schema.job.identity import ManagedIdentitySchema
448+
449+
return ManagedIdentitySchema().dump(self)
450+
451+
@classmethod
452+
def _load_from_dict(cls, data: Dict) -> "ManagedIdentityConfiguration":
453+
# pylint: disable=no-member
454+
from azure.ai.ml._schema.job.identity import ManagedIdentitySchema
455+
456+
return ManagedIdentitySchema().load(data)
457+
421458
def __eq__(self, other: object) -> bool:
422459
if not isinstance(other, ManagedIdentityConfiguration):
423460
return NotImplemented
@@ -429,7 +466,7 @@ class UserIdentityConfiguration(_BaseIdentityConfiguration):
429466

430467
def __init__(self):
431468
super().__init__()
432-
self.type = camel_to_snake(IdentityConfigurationType.USER_IDENTITY)
469+
self.type = IdentityType.USER_IDENTITY
433470

434471
# pylint: disable=no-self-use
435472
def _to_job_rest_object(self) -> RestUserIdentity:
@@ -440,6 +477,19 @@ def _to_job_rest_object(self) -> RestUserIdentity:
440477
def _from_job_rest_object(cls, obj: RestUserIdentity) -> "UserIdentity":
441478
return cls()
442479

480+
def _to_dict(self) -> Dict:
481+
# pylint: disable=no-member
482+
from azure.ai.ml._schema.job.identity import UserIdentitySchema
483+
484+
return UserIdentitySchema().dump(self)
485+
486+
@classmethod
487+
def _load_from_dict(cls, data: Dict) -> "UserIdentityConfiguration":
488+
# pylint: disable=no-member
489+
from azure.ai.ml._schema.job.identity import UserIdentitySchema
490+
491+
return UserIdentitySchema().load(data)
492+
443493
def __eq__(self, other: object) -> bool:
444494
if not isinstance(other, UserIdentityConfiguration):
445495
return NotImplemented
@@ -451,12 +501,25 @@ class AmlTokenConfiguration(_BaseIdentityConfiguration):
451501

452502
def __init__(self):
453503
super().__init__()
454-
self.type = camel_to_snake(IdentityConfigurationType.AML_TOKEN)
504+
self.type = IdentityType.AML_TOKEN
455505

456506
# pylint: disable=no-self-use
457507
def _to_job_rest_object(self) -> RestAmlToken:
458508
return RestAmlToken()
459509

510+
def _to_dict(self) -> Dict:
511+
# pylint: disable=no-member
512+
from azure.ai.ml._schema.job.identity import AMLTokenIdentitySchema
513+
514+
return AMLTokenIdentitySchema().dump(self)
515+
516+
@classmethod
517+
def _load_from_dict(cls, data: Dict) -> "AMLTokenIdentitySchema":
518+
# pylint: disable=no-member
519+
from azure.ai.ml._schema.job.identity import AMLTokenIdentitySchema
520+
521+
return AMLTokenIdentitySchema().load(data)
522+
460523
@classmethod
461524
# pylint: disable=unused-argument
462525
def _from_job_rest_object(cls, obj: RestAmlToken) -> "AmlTokenConfiguration":

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_job/command_job.py

+1
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,7 @@ def _to_node(self, context: Dict = None, **kwargs):
261261
limits=self.limits,
262262
services=self.services,
263263
properties=self.properties,
264+
identity=self.identity,
264265
)
265266

266267
def _validate(self) -> None:

sdk/ml/azure-ai-ml/tests/dsl/e2etests/test_dsl_pipeline.py

+80-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
TensorFlowDistribution,
1717
command,
1818
dsl,
19-
load_component,
19+
load_component, AmlTokenConfiguration, UserIdentityConfiguration, ManagedIdentityConfiguration,
2020
)
2121
from azure.ai.ml._utils._arm_id_utils import is_ARM_id_for_resource
2222
from azure.ai.ml.constants._common import AssetTypes, InputOutputModes
@@ -2448,3 +2448,82 @@ def pipeline_with_default_component():
24482448
pipeline_job = client.jobs.create_or_update(pipeline_with_default_component())
24492449
created_pipeline_job: PipelineJob = client.jobs.get(pipeline_job.name)
24502450
assert created_pipeline_job.jobs["node1"].component == f"{component_name}@default"
2451+
2452+
def test_pipeline_node_identity_with_component(self, client: MLClient):
2453+
path = "./tests/test_configs/components/helloworld_component.yml"
2454+
component_func = load_component(path)
2455+
2456+
@dsl.pipeline
2457+
def pipeline_func(component_in_path):
2458+
node1 = component_func(
2459+
component_in_number=1, component_in_path=component_in_path
2460+
)
2461+
node1.identity = AmlTokenConfiguration()
2462+
2463+
node2 = component_func(
2464+
component_in_number=1, component_in_path=component_in_path
2465+
)
2466+
node2.identity = UserIdentityConfiguration()
2467+
2468+
node3 = component_func(
2469+
component_in_number=1, component_in_path=component_in_path
2470+
)
2471+
node3.identity = ManagedIdentityConfiguration()
2472+
2473+
pipeline = pipeline_func(component_in_path=job_input)
2474+
pipeline_job = client.jobs.create_or_update(pipeline, compute="cpu-cluster")
2475+
omit_fields = [
2476+
"jobs.*.componentId",
2477+
"jobs.*._source"
2478+
]
2479+
actual_dict = omit_with_wildcard(pipeline_job._to_rest_object().as_dict()["properties"], *omit_fields)
2480+
assert actual_dict["jobs"] == {
2481+
'node1': {'computeId': None,
2482+
'display_name': None,
2483+
'distribution': None,
2484+
'environment_variables': {},
2485+
'identity': {'type': 'aml_token'},
2486+
'inputs': {'component_in_number': {'job_input_type': 'literal',
2487+
'value': '1'},
2488+
'component_in_path': {'job_input_type': 'literal',
2489+
'value': '${{parent.inputs.component_in_path}}'}},
2490+
'limits': None,
2491+
'name': 'node1',
2492+
'outputs': {},
2493+
'properties': {},
2494+
'resources': None,
2495+
'tags': {},
2496+
'type': 'command'},
2497+
'node2': {'computeId': None,
2498+
'display_name': None,
2499+
'distribution': None,
2500+
'environment_variables': {},
2501+
'identity': {'type': 'user_identity'},
2502+
'inputs': {'component_in_number': {'job_input_type': 'literal',
2503+
'value': '1'},
2504+
'component_in_path': {'job_input_type': 'literal',
2505+
'value': '${{parent.inputs.component_in_path}}'}},
2506+
'limits': None,
2507+
'name': 'node2',
2508+
'outputs': {},
2509+
'properties': {},
2510+
'resources': None,
2511+
'tags': {},
2512+
'type': 'command'},
2513+
'node3': {'computeId': None,
2514+
'display_name': None,
2515+
'distribution': None,
2516+
'environment_variables': {},
2517+
'identity': {'type': 'managed_identity'},
2518+
'inputs': {'component_in_number': {'job_input_type': 'literal',
2519+
'value': '1'},
2520+
'component_in_path': {'job_input_type': 'literal',
2521+
'value': '${{parent.inputs.component_in_path}}'}},
2522+
'limits': None,
2523+
'name': 'node3',
2524+
'outputs': {},
2525+
'properties': {},
2526+
'resources': None,
2527+
'tags': {},
2528+
'type': 'command'}
2529+
}

sdk/ml/azure-ai-ml/tests/dsl/unittests/test_command_builder.py

+48-1
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,15 @@
1313
command,
1414
load_component,
1515
load_job,
16-
spark,
16+
spark, UserIdentityConfiguration,
1717
)
18+
from azure.ai.ml.dsl import pipeline
1819
from azure.ai.ml.entities import CommandJobLimits, JobResourceConfiguration
1920
from azure.ai.ml.entities._builders import Command
2021
from azure.ai.ml.entities._job.job_service import JobService as JobService
2122
from azure.ai.ml.entities._job.pipeline._component_translatable import ComponentTranslatableMixin
2223
from azure.ai.ml.exceptions import JobException, ValidationException
24+
from test_utilities.utils import omit_with_wildcard
2325

2426
from .._util import _DSL_TIMEOUT_SECOND
2527

@@ -954,3 +956,48 @@ def test_command_hash(self, test_command_params):
954956

955957
node5 = command(**test_command_params, is_deterministic=False)
956958
assert hash(node1) != hash(node5)
959+
960+
def test_pipeline_node_identity_with_builder(self, test_command_params):
961+
test_command_params["identity"] = UserIdentityConfiguration()
962+
command_node = command(**test_command_params)
963+
rest_dict = command_node._to_rest_object()
964+
assert rest_dict["identity"] == {'type': 'user_identity'}
965+
966+
@pipeline
967+
def my_pipeline():
968+
command_node()
969+
970+
pipeline_job = my_pipeline()
971+
omit_fields = [
972+
"jobs.*.componentId",
973+
"jobs.*._source"
974+
]
975+
actual_dict = omit_with_wildcard(pipeline_job._to_rest_object().as_dict()["properties"], *omit_fields)
976+
977+
assert actual_dict["jobs"] == {
978+
'my_job': {'computeId': 'cpu-cluster',
979+
'display_name': 'my-fancy-job',
980+
'distribution': {'distribution_type': 'Mpi',
981+
'process_count_per_instance': 4},
982+
'environment_variables': {'foo': 'bar'},
983+
'identity': {'type': 'user_identity'},
984+
'inputs': {'boolean': {'job_input_type': 'literal',
985+
'value': 'False'},
986+
'float': {'job_input_type': 'literal', 'value': '0.01'},
987+
'integer': {'job_input_type': 'literal', 'value': '1'},
988+
'string': {'job_input_type': 'literal', 'value': 'str'},
989+
'uri_file': {'job_input_type': 'uri_file',
990+
'mode': 'Download',
991+
'uri': 'https://my-blob/path/to/data'},
992+
'uri_folder': {'job_input_type': 'uri_folder',
993+
'mode': 'ReadOnlyMount',
994+
'uri': 'https://my-blob/path/to/data'}},
995+
'limits': None,
996+
'name': 'my_job',
997+
'outputs': {'my_model': {'job_output_type': 'mlflow_model',
998+
'mode': 'ReadWriteMount'}},
999+
'properties': {},
1000+
'resources': None,
1001+
'tags': {},
1002+
'type': 'command'}
1003+
}

0 commit comments

Comments
 (0)