Skip to content

Add DictMixin to easy AML SDK classes #26365

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Oct 11, 2022
4 changes: 3 additions & 1 deletion sdk/ml/azure-ai-ml/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
### Other Changes
- Removed declaration on Python 3.6 support
- Added support for custom setup scripts on compute instances.
- Removed declaration on Python 3.6 support.
- Updated dependencies upper bounds to be major versions.

## 0.1.0b7 (In progress)
Expand All @@ -25,6 +26,7 @@
- Entity load and dump now also accept a file pointer as input.
- Load and dump input names changed from path to 'source' and 'dest', respectively.
- Load and dump 'path' input still works, but is deprecated and emits a warning.
- Most configuration classes from the entity package now implement the standard mapping protocol.
- Managed Identity Support for Compute Instance (experimental).
- Enable using @dsl.pipeline without brackets when no additional parameters.
- Expose Azure subscription Id and resource group name from MLClient objects.
Expand All @@ -35,6 +37,7 @@
- Remove invalid option from create_or_update typehints.
- Change error returned by (begin_)create_or_update invalid input to TypeError.
- Rename set_image_model APIs for all vision tasks to set_training_parameters
- JobOperations.download no longer provides a default value for download_path
- JobOperations.download defaults to "." instead of Path.cwd()
- Workspace.list_keys renamed to Workspace.get_keys.

Expand All @@ -43,7 +46,6 @@
### Other Changes
- Show 'properties' on data assets


## 0.1.0b6

### Features Added
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
from azure.ai.ml._restclient.v2022_01_01_preview.models import Identity as RestIdentity
from azure.ai.ml._utils.utils import camel_to_snake, snake_to_pascal
from azure.ai.ml.entities._mixins import RestTranslatableMixin
from azure.ai.ml.entities._mixins import DictMixin

from ._user_assigned_identity import UserAssignedIdentity


class IdentityConfiguration(RestTranslatableMixin):
class IdentityConfiguration(RestTranslatableMixin, DictMixin):
"""Managed identity specification."""

def __init__(self, *, type: str, user_assigned_identities: List[UserAssignedIdentity] = None):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, TYPE
from azure.ai.ml.constants._compute import ComputeDefaults, ComputeType
from azure.ai.ml.entities._compute.compute import Compute, NetworkSettings
from azure.ai.ml.entities._mixins import DictMixin
from azure.ai.ml.entities._util import load_from_dict
from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException

Expand Down Expand Up @@ -70,7 +71,7 @@ def ssh_port(self) -> str:
return self._ssh_port


class AssignedUserConfiguration:
class AssignedUserConfiguration(DictMixin):
"""Settings to create a compute on behalf of another user."""

def __init__(self, *, user_tenant_id: str, user_object_id: str):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@

from azure.ai.ml._restclient.v2021_10_01.models import CodeConfiguration as RestCodeConfiguration
from azure.ai.ml.entities._assets import Code
from azure.ai.ml.entities._mixins import DictMixin
from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException

module_logger = logging.getLogger(__name__)


class CodeConfiguration:
class CodeConfiguration(DictMixin):
"""CodeConfiguration.

:param code: Code entity, defaults to None
Expand Down
2 changes: 1 addition & 1 deletion sdk/ml/azure-ai-ml/azure/ai/ml/entities/_job/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def _from_rest_object(cls, obj: RestAmlToken) -> "AmlToken":
return cls()


class ManagedIdentity(Identity, DictMixin):
class ManagedIdentity(Identity):
"""Managed identity configuration.

:param client_id: Specifies a user-assigned identity by client ID. For system-assigned, do not
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,9 @@ def _data_binding(self) -> str:
"""Return data binding string representation for this input/output."""
raise NotImplementedError()

# Why did we have this function? It prevents the DictMixin from being applied.
# Unclear if we explicitly do NOT want the mapping protocol to be applied to this, or it this was just
# confirmation that it didn't at the time.
def keys(self):
# This property is introduced to raise catchable Exception in marshmallow mapping validation trial.
raise TypeError(f"'{type(self).__name__}' object is not a mapping")
Expand Down
14 changes: 10 additions & 4 deletions sdk/ml/azure-ai-ml/azure/ai/ml/entities/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@ def _from_rest_object(cls, obj: Any) -> Any:


class DictMixin(object):
def __contains__(self, item):
return self.__dict__.__contains__(item)

def __iter__(self):
return self.__dict__.__iter__()

def __setitem__(self, key, item):
# type: (Any, Any) -> None
self.__dict__[key] = item
Expand All @@ -29,10 +35,6 @@ def __repr__(self):
# type: () -> str
return str(self)

def __len__(self):
# type: () -> int
return len(self.keys())

def __delitem__(self, key):
# type: (Any) -> None
self.__dict__[key] = None
Expand Down Expand Up @@ -79,6 +81,10 @@ def get(self, key, default=None):
return self.__dict__[key]
return default

def __len__(self):
# type: () -> int
return len(self.keys())


class TelemetryMixin:
def _get_telemetry_values(self, *args, **kwargs): # pylint: disable=unused-argument
Expand Down