5
5
# pylint: disable=protected-access,redefined-builtin
6
6
7
7
from abc import ABC
8
- from typing import List
8
+ from typing import List , Dict , Union
9
9
10
10
from azure .ai .ml ._azure_environments import _get_active_directory_url_from_metadata
11
11
from azure .ai .ml ._utils .utils import camel_to_snake , snake_to_pascal
12
- from azure .ai .ml .entities ._mixins import RestTranslatableMixin , DictMixin
12
+ from azure .ai .ml .constants ._common import CommonYamlFields , IdentityType
13
+ from azure .ai .ml .entities ._mixins import RestTranslatableMixin , DictMixin , YamlTranslatableMixin
13
14
from azure .ai .ml ._restclient .v2022_05_01 .models import (
14
15
AccountKeyDatastoreCredentials as RestAccountKeyDatastoreCredentials ,
15
16
AccountKeyDatastoreSecrets as RestAccountKeyDatastoreSecrets ,
46
47
47
48
from azure .ai .ml ._restclient .v2022_10_01_preview .models import IdentityConfiguration as RestJobIdentityConfiguration
48
49
49
- from azure .ai .ml .exceptions import ErrorTarget , ErrorCategory , JobException
50
+ from azure .ai .ml .exceptions import ErrorTarget , ErrorCategory , JobException , ValidationErrorType , ValidationException
50
51
51
52
from azure .ai .ml ._restclient .v2022_05_01 .models import (
52
53
ManagedServiceIdentity as RestManagedServiceIdentityConfiguration ,
@@ -318,7 +319,7 @@ def __ne__(self, other: object) -> bool:
318
319
return not self .__eq__ (other )
319
320
320
321
321
- class _BaseJobIdentityConfiguration (ABC , RestTranslatableMixin , DictMixin ):
322
+ class _BaseJobIdentityConfiguration (ABC , RestTranslatableMixin , DictMixin , YamlTranslatableMixin ):
322
323
def __init__ (self ):
323
324
self .type = None
324
325
@@ -342,6 +343,29 @@ def _from_rest_object(cls, obj: RestJobIdentityConfiguration) -> "Identity":
342
343
error_category = ErrorCategory .SYSTEM_ERROR ,
343
344
)
344
345
346
+ @classmethod
347
+ def _load (
348
+ cls ,
349
+ data : Dict = None ,
350
+ ) -> Union ["ManagedIdentityConfiguration" , "UserIdentityConfiguration" , "AmlTokenConfiguration" ]:
351
+ type_str = data .get (CommonYamlFields .TYPE )
352
+ if type_str == IdentityType .MANAGED_IDENTITY :
353
+ identity_cls = ManagedIdentityConfiguration
354
+ elif type_str == IdentityType .USER_IDENTITY :
355
+ identity_cls = UserIdentityConfiguration
356
+ elif type_str == IdentityType .AML_TOKEN :
357
+ identity_cls = AmlTokenConfiguration
358
+ else :
359
+ msg = f"Unsupported identity type: { type_str } ."
360
+ raise ValidationException (
361
+ message = msg ,
362
+ no_personal_data_message = msg ,
363
+ target = ErrorTarget .IDENTITY ,
364
+ error_category = ErrorCategory .USER_ERROR ,
365
+ error_type = ValidationErrorType .INVALID_VALUE ,
366
+ )
367
+ return identity_cls ._load_from_dict (data )
368
+
345
369
346
370
class ManagedIdentityConfiguration (_BaseIdentityConfiguration ):
347
371
"""Managed Identity Credentials.
@@ -356,7 +380,7 @@ def __init__(
356
380
self , * , client_id : str = None , resource_id : str = None , object_id : str = None , principal_id : str = None
357
381
):
358
382
super ().__init__ ()
359
- self .type = camel_to_snake ( ConnectionAuthType .MANAGED_IDENTITY )
383
+ self .type = IdentityType .MANAGED_IDENTITY
360
384
self .client_id = client_id
361
385
# TODO: Check if both client_id and resource_id are required
362
386
self .resource_id = resource_id
@@ -418,6 +442,19 @@ def _from_workspace_rest_object(cls, obj: RestUserAssignedIdentityConfiguration)
418
442
client_id = obj .client_id ,
419
443
)
420
444
445
+ def _to_dict (self ) -> Dict :
446
+ # pylint: disable=no-member
447
+ from azure .ai .ml ._schema .job .identity import ManagedIdentitySchema
448
+
449
+ return ManagedIdentitySchema ().dump (self )
450
+
451
+ @classmethod
452
+ def _load_from_dict (cls , data : Dict ) -> "ManagedIdentityConfiguration" :
453
+ # pylint: disable=no-member
454
+ from azure .ai .ml ._schema .job .identity import ManagedIdentitySchema
455
+
456
+ return ManagedIdentitySchema ().load (data )
457
+
421
458
def __eq__ (self , other : object ) -> bool :
422
459
if not isinstance (other , ManagedIdentityConfiguration ):
423
460
return NotImplemented
@@ -429,7 +466,7 @@ class UserIdentityConfiguration(_BaseIdentityConfiguration):
429
466
430
467
def __init__ (self ):
431
468
super ().__init__ ()
432
- self .type = camel_to_snake ( IdentityConfigurationType .USER_IDENTITY )
469
+ self .type = IdentityType .USER_IDENTITY
433
470
434
471
# pylint: disable=no-self-use
435
472
def _to_job_rest_object (self ) -> RestUserIdentity :
@@ -440,6 +477,19 @@ def _to_job_rest_object(self) -> RestUserIdentity:
440
477
def _from_job_rest_object (cls , obj : RestUserIdentity ) -> "UserIdentity" :
441
478
return cls ()
442
479
480
+ def _to_dict (self ) -> Dict :
481
+ # pylint: disable=no-member
482
+ from azure .ai .ml ._schema .job .identity import UserIdentitySchema
483
+
484
+ return UserIdentitySchema ().dump (self )
485
+
486
+ @classmethod
487
+ def _load_from_dict (cls , data : Dict ) -> "UserIdentityConfiguration" :
488
+ # pylint: disable=no-member
489
+ from azure .ai .ml ._schema .job .identity import UserIdentitySchema
490
+
491
+ return UserIdentitySchema ().load (data )
492
+
443
493
def __eq__ (self , other : object ) -> bool :
444
494
if not isinstance (other , UserIdentityConfiguration ):
445
495
return NotImplemented
@@ -451,12 +501,25 @@ class AmlTokenConfiguration(_BaseIdentityConfiguration):
451
501
452
502
def __init__ (self ):
453
503
super ().__init__ ()
454
- self .type = camel_to_snake ( IdentityConfigurationType .AML_TOKEN )
504
+ self .type = IdentityType .AML_TOKEN
455
505
456
506
# pylint: disable=no-self-use
457
507
def _to_job_rest_object (self ) -> RestAmlToken :
458
508
return RestAmlToken ()
459
509
510
+ def _to_dict (self ) -> Dict :
511
+ # pylint: disable=no-member
512
+ from azure .ai .ml ._schema .job .identity import AMLTokenIdentitySchema
513
+
514
+ return AMLTokenIdentitySchema ().dump (self )
515
+
516
+ @classmethod
517
+ def _load_from_dict (cls , data : Dict ) -> "AMLTokenIdentitySchema" :
518
+ # pylint: disable=no-member
519
+ from azure .ai .ml ._schema .job .identity import AMLTokenIdentitySchema
520
+
521
+ return AMLTokenIdentitySchema ().load (data )
522
+
460
523
@classmethod
461
524
# pylint: disable=unused-argument
462
525
def _from_job_rest_object (cls , obj : RestAmlToken ) -> "AmlTokenConfiguration" :
0 commit comments