Skip to content

Support identity in pipeline node #26975

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,13 @@ class CommandSchema(BaseNodeSchema, ParameterizedCommandSchema):
],
)
services = fields.Dict(keys=fields.Str(), values=NestedField(JobServiceSchema))
identity = UnionField(
[
NestedField(ManagedIdentitySchema),
NestedField(AMLTokenIdentitySchema),
NestedField(UserIdentitySchema),
]
)

@post_load
def make(self, data, **kwargs) -> "Command":
Expand Down
6 changes: 6 additions & 0 deletions sdk/ml/azure-ai-ml/azure/ai/ml/constants/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,3 +571,9 @@ class RollingRate:
class Scope:
SUBSCRIPTION="subscription"
RESOURCE_GROUP="resource_group"


class IdentityType:
AML_TOKEN = "aml_token"
USER_IDENTITY = "user_identity"
MANAGED_IDENTITY = "managed_identity"
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,7 @@ def _to_rest_object(self, **kwargs) -> dict:
"limits": get_rest_dict_for_node_attrs(self.limits, clear_empty_value=True),
"resources": get_rest_dict_for_node_attrs(self.resources, clear_empty_value=True),
"services": get_rest_dict_for_node_attrs(self.services),
"identity": self.identity._to_dict() if self.identity else None,
}.items():
if value is not None:
rest_obj[key] = value
Expand Down Expand Up @@ -529,6 +530,9 @@ def _from_rest_object(cls, obj: dict) -> "Command":
rest_limits = RestCommandJobLimits.from_dict(obj["limits"])
obj["limits"] = CommandJobLimits()._from_rest_object(rest_limits)

if "identity" in obj and obj["identity"]:
obj["identity"] = _BaseJobIdentityConfiguration._load(obj["identity"])

return Command(**obj)

@classmethod
Expand Down Expand Up @@ -621,6 +625,7 @@ def __call__(self, *args, **kwargs) -> "Command":
node.distribution = copy.deepcopy(self.distribution)
node.resources = copy.deepcopy(self.resources)
node.services = copy.deepcopy(self.services)
node.identity = copy.deepcopy(self.identity)
return node
msg = "Command can be called as a function only when referenced component is {}, currently got {}."
raise ValidationException(
Expand Down
77 changes: 70 additions & 7 deletions sdk/ml/azure-ai-ml/azure/ai/ml/entities/_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
# pylint: disable=protected-access,redefined-builtin

from abc import ABC
from typing import List
from typing import List, Dict, Union

from azure.ai.ml._azure_environments import _get_active_directory_url_from_metadata
from azure.ai.ml._utils.utils import camel_to_snake, snake_to_pascal
from azure.ai.ml.entities._mixins import RestTranslatableMixin, DictMixin
from azure.ai.ml.constants._common import CommonYamlFields, IdentityType
from azure.ai.ml.entities._mixins import RestTranslatableMixin, DictMixin, YamlTranslatableMixin
from azure.ai.ml._restclient.v2022_05_01.models import (
AccountKeyDatastoreCredentials as RestAccountKeyDatastoreCredentials,
AccountKeyDatastoreSecrets as RestAccountKeyDatastoreSecrets,
Expand Down Expand Up @@ -46,7 +47,7 @@

from azure.ai.ml._restclient.v2022_10_01_preview.models import IdentityConfiguration as RestJobIdentityConfiguration

from azure.ai.ml.exceptions import ErrorTarget, ErrorCategory, JobException
from azure.ai.ml.exceptions import ErrorTarget, ErrorCategory, JobException, ValidationErrorType, ValidationException

from azure.ai.ml._restclient.v2022_05_01.models import (
ManagedServiceIdentity as RestManagedServiceIdentityConfiguration,
Expand Down Expand Up @@ -318,7 +319,7 @@ def __ne__(self, other: object) -> bool:
return not self.__eq__(other)


class _BaseJobIdentityConfiguration(ABC, RestTranslatableMixin, DictMixin):
class _BaseJobIdentityConfiguration(ABC, RestTranslatableMixin, DictMixin, YamlTranslatableMixin):
def __init__(self):
self.type = None

Expand All @@ -342,6 +343,29 @@ def _from_rest_object(cls, obj: RestJobIdentityConfiguration) -> "Identity":
error_category=ErrorCategory.SYSTEM_ERROR,
)

@classmethod
def _load(
cls,
data: Dict = None,
) -> Union["ManagedIdentityConfiguration", "UserIdentityConfiguration", "AmlTokenConfiguration"]:
type_str = data.get(CommonYamlFields.TYPE)
if type_str == IdentityType.MANAGED_IDENTITY:
identity_cls = ManagedIdentityConfiguration
elif type_str == IdentityType.USER_IDENTITY:
identity_cls = UserIdentityConfiguration
elif type_str == IdentityType.AML_TOKEN:
identity_cls = AmlTokenConfiguration
else:
msg = f"Unsupported identity type: {type_str}."
raise ValidationException(
message=msg,
no_personal_data_message=msg,
target=ErrorTarget.IDENTITY,
error_category=ErrorCategory.USER_ERROR,
error_type=ValidationErrorType.INVALID_VALUE,
)
return identity_cls._load_from_dict(data)


class ManagedIdentityConfiguration(_BaseIdentityConfiguration):
"""Managed Identity Credentials.
Expand All @@ -356,7 +380,7 @@ def __init__(
self, *, client_id: str = None, resource_id: str = None, object_id: str = None, principal_id: str = None
):
super().__init__()
self.type = camel_to_snake(ConnectionAuthType.MANAGED_IDENTITY)
self.type = IdentityType.MANAGED_IDENTITY
self.client_id = client_id
# TODO: Check if both client_id and resource_id are required
self.resource_id = resource_id
Expand Down Expand Up @@ -418,6 +442,19 @@ def _from_workspace_rest_object(cls, obj: RestUserAssignedIdentityConfiguration)
client_id=obj.client_id,
)

def _to_dict(self) -> Dict:
# pylint: disable=no-member
from azure.ai.ml._schema.job.identity import ManagedIdentitySchema

return ManagedIdentitySchema().dump(self)

@classmethod
def _load_from_dict(cls, data: Dict) -> "ManagedIdentityConfiguration":
# pylint: disable=no-member
from azure.ai.ml._schema.job.identity import ManagedIdentitySchema

return ManagedIdentitySchema().load(data)

def __eq__(self, other: object) -> bool:
if not isinstance(other, ManagedIdentityConfiguration):
return NotImplemented
Expand All @@ -429,7 +466,7 @@ class UserIdentityConfiguration(_BaseIdentityConfiguration):

def __init__(self):
super().__init__()
self.type = camel_to_snake(IdentityConfigurationType.USER_IDENTITY)
self.type = IdentityType.USER_IDENTITY

# pylint: disable=no-self-use
def _to_job_rest_object(self) -> RestUserIdentity:
Expand All @@ -440,6 +477,19 @@ def _to_job_rest_object(self) -> RestUserIdentity:
def _from_job_rest_object(cls, obj: RestUserIdentity) -> "UserIdentity":
return cls()

def _to_dict(self) -> Dict:
# pylint: disable=no-member
from azure.ai.ml._schema.job.identity import UserIdentitySchema

return UserIdentitySchema().dump(self)

@classmethod
def _load_from_dict(cls, data: Dict) -> "UserIdentityConfiguration":
# pylint: disable=no-member
from azure.ai.ml._schema.job.identity import UserIdentitySchema

return UserIdentitySchema().load(data)

def __eq__(self, other: object) -> bool:
if not isinstance(other, UserIdentityConfiguration):
return NotImplemented
Expand All @@ -451,12 +501,25 @@ class AmlTokenConfiguration(_BaseIdentityConfiguration):

def __init__(self):
super().__init__()
self.type = camel_to_snake(IdentityConfigurationType.AML_TOKEN)
self.type = IdentityType.AML_TOKEN

# pylint: disable=no-self-use
def _to_job_rest_object(self) -> RestAmlToken:
return RestAmlToken()

def _to_dict(self) -> Dict:
# pylint: disable=no-member
from azure.ai.ml._schema.job.identity import AMLTokenIdentitySchema

return AMLTokenIdentitySchema().dump(self)

@classmethod
def _load_from_dict(cls, data: Dict) -> "AMLTokenIdentitySchema":
# pylint: disable=no-member
from azure.ai.ml._schema.job.identity import AMLTokenIdentitySchema

return AMLTokenIdentitySchema().load(data)

@classmethod
# pylint: disable=unused-argument
def _from_job_rest_object(cls, obj: RestAmlToken) -> "AmlTokenConfiguration":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ def _to_node(self, context: Dict = None, **kwargs):
limits=self.limits,
services=self.services,
properties=self.properties,
identity=self.identity,
)

def _validate(self) -> None:
Expand Down
81 changes: 80 additions & 1 deletion sdk/ml/azure-ai-ml/tests/dsl/e2etests/test_dsl_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
TensorFlowDistribution,
command,
dsl,
load_component,
load_component, AmlTokenConfiguration, UserIdentityConfiguration, ManagedIdentityConfiguration,
)
from azure.ai.ml._utils._arm_id_utils import is_ARM_id_for_resource
from azure.ai.ml.constants._common import AssetTypes, InputOutputModes
Expand Down Expand Up @@ -2448,3 +2448,82 @@ def pipeline_with_default_component():
pipeline_job = client.jobs.create_or_update(pipeline_with_default_component())
created_pipeline_job: PipelineJob = client.jobs.get(pipeline_job.name)
assert created_pipeline_job.jobs["node1"].component == f"{component_name}@default"

def test_pipeline_node_identity_with_component(self, client: MLClient):
path = "./tests/test_configs/components/helloworld_component.yml"
component_func = load_component(path)

@dsl.pipeline
def pipeline_func(component_in_path):
node1 = component_func(
component_in_number=1, component_in_path=component_in_path
)
node1.identity = AmlTokenConfiguration()

node2 = component_func(
component_in_number=1, component_in_path=component_in_path
)
node2.identity = UserIdentityConfiguration()

node3 = component_func(
component_in_number=1, component_in_path=component_in_path
)
node3.identity = ManagedIdentityConfiguration()

pipeline = pipeline_func(component_in_path=job_input)
pipeline_job = client.jobs.create_or_update(pipeline, compute="cpu-cluster")
omit_fields = [
"jobs.*.componentId",
"jobs.*._source"
]
actual_dict = omit_with_wildcard(pipeline_job._to_rest_object().as_dict()["properties"], *omit_fields)
assert actual_dict["jobs"] == {
'node1': {'computeId': None,
'display_name': None,
'distribution': None,
'environment_variables': {},
'identity': {'type': 'aml_token'},
'inputs': {'component_in_number': {'job_input_type': 'literal',
'value': '1'},
'component_in_path': {'job_input_type': 'literal',
'value': '${{parent.inputs.component_in_path}}'}},
'limits': None,
'name': 'node1',
'outputs': {},
'properties': {},
'resources': None,
'tags': {},
'type': 'command'},
'node2': {'computeId': None,
'display_name': None,
'distribution': None,
'environment_variables': {},
'identity': {'type': 'user_identity'},
'inputs': {'component_in_number': {'job_input_type': 'literal',
'value': '1'},
'component_in_path': {'job_input_type': 'literal',
'value': '${{parent.inputs.component_in_path}}'}},
'limits': None,
'name': 'node2',
'outputs': {},
'properties': {},
'resources': None,
'tags': {},
'type': 'command'},
'node3': {'computeId': None,
'display_name': None,
'distribution': None,
'environment_variables': {},
'identity': {'type': 'managed_identity'},
'inputs': {'component_in_number': {'job_input_type': 'literal',
'value': '1'},
'component_in_path': {'job_input_type': 'literal',
'value': '${{parent.inputs.component_in_path}}'}},
'limits': None,
'name': 'node3',
'outputs': {},
'properties': {},
'resources': None,
'tags': {},
'type': 'command'}
}
49 changes: 48 additions & 1 deletion sdk/ml/azure-ai-ml/tests/dsl/unittests/test_command_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@
command,
load_component,
load_job,
spark,
spark, UserIdentityConfiguration,
)
from azure.ai.ml.dsl import pipeline
from azure.ai.ml.entities import CommandJobLimits, JobResourceConfiguration
from azure.ai.ml.entities._builders import Command
from azure.ai.ml.entities._job.job_service import JobService as JobService
from azure.ai.ml.entities._job.pipeline._component_translatable import ComponentTranslatableMixin
from azure.ai.ml.exceptions import JobException, ValidationException
from test_utilities.utils import omit_with_wildcard

from .._util import _DSL_TIMEOUT_SECOND

Expand Down Expand Up @@ -954,3 +956,48 @@ def test_command_hash(self, test_command_params):

node5 = command(**test_command_params, is_deterministic=False)
assert hash(node1) != hash(node5)

def test_pipeline_node_identity_with_builder(self, test_command_params):
test_command_params["identity"] = UserIdentityConfiguration()
command_node = command(**test_command_params)
rest_dict = command_node._to_rest_object()
assert rest_dict["identity"] == {'type': 'user_identity'}

@pipeline
def my_pipeline():
command_node()

pipeline_job = my_pipeline()
omit_fields = [
"jobs.*.componentId",
"jobs.*._source"
]
actual_dict = omit_with_wildcard(pipeline_job._to_rest_object().as_dict()["properties"], *omit_fields)

assert actual_dict["jobs"] == {
'my_job': {'computeId': 'cpu-cluster',
'display_name': 'my-fancy-job',
'distribution': {'distribution_type': 'Mpi',
'process_count_per_instance': 4},
'environment_variables': {'foo': 'bar'},
'identity': {'type': 'user_identity'},
'inputs': {'boolean': {'job_input_type': 'literal',
'value': 'False'},
'float': {'job_input_type': 'literal', 'value': '0.01'},
'integer': {'job_input_type': 'literal', 'value': '1'},
'string': {'job_input_type': 'literal', 'value': 'str'},
'uri_file': {'job_input_type': 'uri_file',
'mode': 'Download',
'uri': 'https://my-blob/path/to/data'},
'uri_folder': {'job_input_type': 'uri_folder',
'mode': 'ReadOnlyMount',
'uri': 'https://my-blob/path/to/data'}},
'limits': None,
'name': 'my_job',
'outputs': {'my_model': {'job_output_type': 'mlflow_model',
'mode': 'ReadWriteMount'}},
'properties': {},
'resources': None,
'tags': {},
'type': 'command'}
}
Loading