diff --git a/sdk/ml/azure-ai-ml/CHANGELOG.md b/sdk/ml/azure-ai-ml/CHANGELOG.md index d188db48d4a3..36ac6e3b50d4 100644 --- a/sdk/ml/azure-ai-ml/CHANGELOG.md +++ b/sdk/ml/azure-ai-ml/CHANGELOG.md @@ -3,6 +3,7 @@ ## 1.1.0 (Unreleased) ### Features Added +- Registry list operation now accepts scope value to allow subscription-only based requests. - Most configuration classes from the entity package now implement the standard mapping protocol. ### Breaking Changes diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/constants/__init__.py b/sdk/ml/azure-ai-ml/azure/ai/ml/constants/__init__.py index 63c87963b663..1d2e41955255 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/constants/__init__.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/constants/__init__.py @@ -4,7 +4,7 @@ __path__ = __import__("pkgutil").extend_path(__path__, __name__) -from ._common import AssetTypes, InputOutputModes, ModelType, TimeZone +from ._common import AssetTypes, InputOutputModes, ModelType, TimeZone, Scope from ._component import ParallelTaskType from ._deployment import BatchDeploymentOutputAction from ._job import ( @@ -38,4 +38,5 @@ "AcrAccountSku", "NlpModels", "NlpLearningRateScheduler", + "Scope", ] diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/constants/_common.py b/sdk/ml/azure-ai-ml/azure/ai/ml/constants/_common.py index 2628287d8f6c..4c09996db29e 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/constants/_common.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/constants/_common.py @@ -562,3 +562,8 @@ class RollingRate: DAY = "day" HOUR = "hour" MINUTE = "minute" + + +class Scope: + SUBSCRIPTION="subscription" + RESOURCE_GROUP="resource_group" diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_registry_operations.py b/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_registry_operations.py index 1d9c22313bb5..512776c0d485 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_registry_operations.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_registry_operations.py @@ -19,7 +19,7 @@ from azure.ai.ml._utils._experimental import experimental from .._utils._azureml_polling import AzureMLPolling -from ..constants._common import LROConfigurations +from ..constants._common import LROConfigurations, Scope ops_logger = OpsLogger(__name__) module_logger = ops_logger.module_logger @@ -52,14 +52,19 @@ def __init__( self._init_kwargs = kwargs #@ monitor_with_activity(logger, "Registry.List", ActivityType.PUBLICAPI) - def list(self) -> Iterable[Registry]: + def list(self, *, scope: str = Scope.RESOURCE_GROUP) -> Iterable[Registry]: """List all registries that the user has access to in the current - resource group. + resource group or subscription. + :param scope: scope of the listing, "resource_group" or "subscription", defaults to "resource_group" + :type scope: str, optional :return: An iterator like instance of Registry objects :rtype: ~azure.core.paging.ItemPaged[Registry] """ - + if scope.lower() == Scope.SUBSCRIPTION: + return self._operation.list_by_subscription( + cls=lambda objs: [Registry._from_rest_object(obj) for obj in objs] + ) return self._operation.list(cls=lambda objs: [Registry._from_rest_object(obj) for obj in objs], \ resource_group_name=self._resource_group_name) diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_workspace_operations.py b/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_workspace_operations.py index 0ddcbfaa2ceb..b8f513f5983a 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_workspace_operations.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_workspace_operations.py @@ -32,7 +32,7 @@ from azure.ai.ml._utils.utils import camel_to_snake from azure.ai.ml._version import VERSION from azure.ai.ml.constants import ManagedServiceIdentityType -from azure.ai.ml.constants._common import ArmConstants, LROConfigurations, WorkspaceResourceConstants +from azure.ai.ml.constants._common import ArmConstants, LROConfigurations, WorkspaceResourceConstants, Scope from azure.ai.ml.entities._credentials import IdentityConfiguration from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException from azure.core.credentials import TokenCredential @@ -70,7 +70,7 @@ def __init__( self.containerRegistry = "none" # @monitor_with_activity(logger, "Workspace.List", ActivityType.PUBLICAPI) - def list(self, *, scope: str = "resource_group") -> Iterable[Workspace]: + def list(self, *, scope: str = Scope.RESOURCE_GROUP) -> Iterable[Workspace]: """List all workspaces that the user has access to in the current resource group or subscription. @@ -80,7 +80,7 @@ def list(self, *, scope: str = "resource_group") -> Iterable[Workspace]: :rtype: ~azure.core.paging.ItemPaged[Workspace] """ - if scope == "subscription": + if scope == Scope.SUBSCRIPTION: return self._operation.list_by_subscription( cls=lambda objs: [Workspace._from_rest_object(obj) for obj in objs] ) diff --git a/sdk/ml/azure-ai-ml/tests/registry/unittests/test_registry_operations.py b/sdk/ml/azure-ai-ml/tests/registry/unittests/test_registry_operations.py index acb5d372d5b3..47f4013c13c1 100644 --- a/sdk/ml/azure-ai-ml/tests/registry/unittests/test_registry_operations.py +++ b/sdk/ml/azure-ai-ml/tests/registry/unittests/test_registry_operations.py @@ -29,9 +29,18 @@ def mock_registry_operation( @pytest.mark.unittest class TestRegistryOperations: def test_list(self, mock_registry_operation: RegistryOperations) -> None: + # Test different input options for the scope value mock_registry_operation.list() mock_registry_operation._operation.list.assert_called_once() + mock_registry_operation.list(scope="invalid") + assert mock_registry_operation._operation.list.call_count == 2 + mock_registry_operation._operation.list_by_subscription.assert_not_called() + + mock_registry_operation.list(scope="subscription") + assert mock_registry_operation._operation.list.call_count == 2 + mock_registry_operation._operation.list_by_subscription.assert_called_once() + def test_get(self, mock_registry_operation: RegistryOperations, randstr: Callable[[], str]) -> None: mock_registry_operation.get(f"unittest_{randstr('reg_name')}") mock_registry_operation._operation.get.assert_called_once()