diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/_azure_environments.py b/sdk/ml/azure-ai-ml/azure/ai/ml/_azure_environments.py index 325173b229c7..c42317356990 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/_azure_environments.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/_azure_environments.py @@ -10,6 +10,11 @@ from azure.ai.ml._utils.utils import _get_mfe_url_override from azure.ai.ml.constants._common import AZUREML_CLOUD_ENV_NAME +from azure.ai.ml.constants._common import ArmConstants +from azure.core.rest import HttpRequest +from azure.mgmt.core import ARMPipelineClient + + module_logger = logging.getLogger(__name__) @@ -56,6 +61,19 @@ class EndpointURLS: # pylint: disable=too-few-public-methods,no-init }, } +_requests_pipeline = None + +def _get_cloud(cloud: str): + if cloud in _environments: + return _environments[cloud] + arm_url = os.environ.get(ArmConstants.METADATA_URL_ENV_NAME,ArmConstants.DEFAULT_URL) + arm_clouds = _get_clouds_by_metadata_url(arm_url) + try: + new_cloud = arm_clouds[cloud] + _environments.update(new_cloud) + return new_cloud + except KeyError: + raise Exception('Unknown cloud environment "{0}".'.format(cloud)) def _get_default_cloud_name(): """Return AzureCloud as the default cloud.""" @@ -74,17 +92,18 @@ def _get_cloud_details(cloud: str = AzureEnvironments.ENV_DEFAULT): AzureEnvironments.ENV_DEFAULT, ) cloud = _get_default_cloud_name() - try: - azure_environment = _environments[cloud] - module_logger.debug("Using the cloud configuration: '%s'.", azure_environment) - except KeyError: - raise Exception('Unknown cloud environment "{0}".'.format(cloud)) - return azure_environment + return _get_cloud(cloud) def _set_cloud(cloud: str = AzureEnvironments.ENV_DEFAULT): + """Sets the current cloud + + :param cloud: cloud name + """ if cloud is not None: - if cloud not in _environments: + try: + _get_cloud(cloud) + except Exception: raise Exception('Unknown cloud environment supplied: "{0}".'.format(cloud)) else: cloud = _get_default_cloud_name() @@ -189,3 +208,74 @@ def _resource_to_scopes(resource): """ scope = resource + "/.default" return [scope] + +def _get_registry_discovery_url(cloud, cloud_suffix=""): + """Get or generate the registry discovery url + + :param cloud: configuration of the cloud to get the registry_discovery_url from + :param cloud_suffix: the suffix to use for the cloud, in the case that the registry_discovery_url + must be generated + :return: string of discovery url + """ + cloud_name = cloud["name"] + if cloud_name in _environments: + return _environments[cloud_name].registry_url + + registry_discovery_region = os.environ.get( + ArmConstants.REGISTRY_DISCOVERY_REGION_ENV_NAME, + ArmConstants.REGISTRY_DISCOVERY_DEFAULT_REGION + ) + registry_discovery_region_default = "https://{}{}.api.azureml.{}/".format( + cloud_name.lower(), + registry_discovery_region, + cloud_suffix + ) + return os.environ.get(ArmConstants.REGISTRY_ENV_URL, registry_discovery_region_default) + +def _get_clouds_by_metadata_url(metadata_url): + """Get all the clouds by the specified metadata url + + :return: list of the clouds + """ + try: + module_logger.debug('Start : Loading cloud metadata from the url specified by %s', metadata_url) + client = ARMPipelineClient(base_url=metadata_url, policies=[]) + HttpRequest("GET", metadata_url) + with client.send_request(HttpRequest("GET", metadata_url)) as meta_response: + arm_cloud_dict = meta_response.json() + cli_cloud_dict = _convert_arm_to_cli(arm_cloud_dict) + module_logger.debug('Finish : Loading cloud metadata from the url specified by %s', metadata_url) + return cli_cloud_dict + except Exception as ex: # pylint: disable=broad-except + module_logger.warning("Error: Azure ML was unable to load cloud metadata from the url specified by %s. %s. " + "This may be due to a misconfiguration of networking controls. Azure Machine Learning Python " + "SDK requires outbound access to Azure Resource Manager. Please contact your networking team " + "to configure outbound access to Azure Resource Manager on both Network Security Group and " + "Firewall. For more details on required configurations, see " + "https://docs.microsoft.com/azure/machine-learning/how-to-access-azureml-behind-firewall.", + metadata_url, ex) + return {} + +def _convert_arm_to_cli(arm_cloud_metadata): + cli_cloud_metadata_dict = {} + if isinstance(arm_cloud_metadata, dict): + arm_cloud_metadata = [arm_cloud_metadata] + + for cloud in arm_cloud_metadata: + try: + cloud_name = cloud["name"] + portal_endpoint = cloud["portal"] + cloud_suffix = ".".join(portal_endpoint.split('.')[2:]).replace("/", "") + registry_discovery_url = _get_registry_discovery_url(cloud, cloud_suffix) + cli_cloud_metadata_dict[cloud_name] = { + EndpointURLS.AZURE_PORTAL_ENDPOINT: cloud["portal"], + EndpointURLS.RESOURCE_MANAGER_ENDPOINT: cloud["resourceManager"], + EndpointURLS.ACTIVE_DIRECTORY_ENDPOINT: cloud["authentication"]["loginEndpoint"], + EndpointURLS.AML_RESOURCE_ID: "https://ml.azure.{}".format(cloud_suffix), + EndpointURLS.STORAGE_ENDPOINT: cloud["suffixes"]["storage"], + EndpointURLS.REGISTRY_DISCOVERY_ENDPOINT: registry_discovery_url + } + except KeyError as ex: + module_logger.warning("Property on cloud not found in arm cloud metadata: %s", ex) + continue + return cli_cloud_metadata_dict diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/constants/_common.py b/sdk/ml/azure-ai-ml/azure/ai/ml/constants/_common.py index 1832f8641c2e..fd01fd4ca5e9 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/constants/_common.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/constants/_common.py @@ -273,6 +273,12 @@ class ArmConstants(object): AZURE_MGMT_KEYVAULT_API_VERSION = "2019-09-01" AZURE_MGMT_CONTAINER_REG_API_VERSION = "2019-05-01" + DEFAULT_URL = "https://management.azure.com/metadata/endpoints?api-version=2019-05-01" + METADATA_URL_ENV_NAME = "ARM_CLOUD_METADATA_URL" + REGISTRY_DISCOVERY_DEFAULT_REGION = "west" + REGISTRY_DISCOVERY_REGION_ENV_NAME = "REGISTRY_DISCOVERY_ENDPOINT_REGION" + REGISTRY_ENV_URL = "REGISTRY_DISCOVERY_ENDPOINT_URL" + class HttpResponseStatusCode(object): NOT_FOUND = 404 diff --git a/sdk/ml/azure-ai-ml/tests/internal_utils/unittests/test_cloud_environments.py b/sdk/ml/azure-ai-ml/tests/internal_utils/unittests/test_cloud_environments.py index be79f6fd68a1..93f0df644dee 100644 --- a/sdk/ml/azure-ai-ml/tests/internal_utils/unittests/test_cloud_environments.py +++ b/sdk/ml/azure-ai-ml/tests/internal_utils/unittests/test_cloud_environments.py @@ -1,24 +1,66 @@ import os - import mock import pytest +from mock import MagicMock, patch from azure.ai.ml._azure_environments import ( AzureEnvironments, + EndpointURLS, _get_azure_portal_id_from_metadata, _get_base_url_from_metadata, + _get_cloud_details, _get_cloud_information_from_metadata, _get_default_cloud_name, _get_registry_discovery_endpoint_from_metadata, _get_storage_endpoint_from_metadata, _set_cloud, ) -from azure.ai.ml.constants._common import AZUREML_CLOUD_ENV_NAME +from azure.ai.ml.constants._common import ArmConstants, AZUREML_CLOUD_ENV_NAME +from azure.mgmt.core import ARMPipelineClient + +def mocked_send_request_get(*args, **kwargs): + class MockResponse: + def __init__(self): + self.status_code = 201 + def __enter__(self): + return self + def __exit__(self, exc_type, exc_value, traceback): + return + def json(self): + return [ + { + "name": "TEST_ENV", + "portal": "testportal.azure.com", + "resourceManager": "testresourcemanager.azure.com", + "authentication": { + "loginEndpoint": "testdirectoryendpoint.azure.com" + }, + "suffixes": { + "storage": "teststorageendpoint" + } + }, + { + "name": "TEST_ENV2", + "portal": "testportal.azure.windows.net", + "resourceManager": "testresourcemanager.azure.com", + "authentication": { + "loginEndpoint": "testdirectoryendpoint.azure.com" + }, + "suffixes": { + "storage": "teststorageendpoint" + } + }, + { + "name": "MISCONFIGURED" + } + ] + return MockResponse() @pytest.mark.unittest @pytest.mark.core_sdk_test class TestCloudEnvironments: + @mock.patch.dict(os.environ, {AZUREML_CLOUD_ENV_NAME: AzureEnvironments.ENV_DEFAULT}, clear=True) def test_set_valid_cloud_details_china(self): cloud_environment = AzureEnvironments.ENV_CHINA @@ -70,7 +112,6 @@ def test_get_default_cloud(self): with mock.patch("os.environ", {AZUREML_CLOUD_ENV_NAME: "yadadada"}): cloud_name = _get_default_cloud_name() assert cloud_name == "yadadada" - def test_get_registry_endpoint_from_public(self): cloud_environment = AzureEnvironments.ENV_DEFAULT @@ -88,4 +129,36 @@ def test_get_registry_endpoint_from_us_gov(self): cloud_environment = AzureEnvironments.ENV_US_GOVERNMENT _set_cloud(cloud_environment) base_url = _get_registry_discovery_endpoint_from_metadata(cloud_environment) - assert "https://usgovarizona.api.ml.azure.us/" in base_url \ No newline at end of file + assert "https://usgovarizona.api.ml.azure.us/" in base_url + + @mock.patch.dict(os.environ, {}, clear=True) + @mock.patch("azure.mgmt.core.ARMPipelineClient.send_request", side_effect=mocked_send_request_get) + def test_get_cloud_from_arm(self, mock_arm_pipeline_client_send_request): + + _set_cloud('TEST_ENV') + cloud_details = _get_cloud_information_from_metadata("TEST_ENV") + assert cloud_details.get("cloud") == "TEST_ENV" + + @mock.patch.dict(os.environ, {}, clear=True) + @mock.patch("azure.mgmt.core.ARMPipelineClient.send_request", side_effect=mocked_send_request_get) + def test_all_endpointurls_used(self, mock_get): + cloud_details = _get_cloud_details("TEST_ENV") + endpoint_urls = [a for a in dir(EndpointURLS) if not a.startswith('__')] + for url in endpoint_urls: + try: + cloud_details[EndpointURLS.__dict__[url]] + except: + assert False, "Url not found: {}".format(EndpointURLS.__dict__[url]) + assert True + + @mock.patch.dict(os.environ, {}, clear=True) + @mock.patch("azure.mgmt.core.ARMPipelineClient.send_request", side_effect=mocked_send_request_get) + def test_metadata_registry_endpoint(self, mock_get): + cloud_details = _get_cloud_details("TEST_ENV2") + assert cloud_details.get(EndpointURLS.REGISTRY_DISCOVERY_ENDPOINT) == "https://test_env2west.api.azureml.windows.net/" + + @mock.patch.dict(os.environ, {}, clear=True) + @mock.patch("azure.mgmt.core.ARMPipelineClient.send_request", side_effect=mocked_send_request_get) + def test_arm_misconfigured(self, mock_get): + with pytest.raises(Exception) as e_info: + _set_cloud("MISCONFIGURED")