Skip to content

Commit 28fbf7c

Browse files
committed
Resource id adjustments
Resource ID as mi_res_id in App Service, as msi_res_id in other flavors
1 parent 0a756e9 commit 28fbf7c

File tree

2 files changed

+41
-4
lines changed

2 files changed

+41
-4
lines changed

msal/managed_identity.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class ManagedIdentity(UserDict):
4040

4141
_types_mapping = { # Maps type name in configuration to type name on wire
4242
CLIENT_ID: "client_id",
43-
RESOURCE_ID: "mi_res_id",
43+
RESOURCE_ID: "msi_res_id", # VM's IMDS prefers msi_res_id https://github.com/Azure/azure-rest-api-specs/blob/dba6ed1f03bda88ac6884c0a883246446cc72495/specification/imds/data-plane/Microsoft.InstanceMetadataService/stable/2018-10-01/imds.json#L233-L239
4444
OBJECT_ID: "object_id",
4545
}
4646

@@ -430,9 +430,9 @@ def _obtain_token(http_client, managed_identity, resource):
430430
return _obtain_token_on_azure_vm(http_client, managed_identity, resource)
431431

432432

433-
def _adjust_param(params, managed_identity):
433+
def _adjust_param(params, managed_identity, types_mapping=None):
434434
# Modify the params dict in place
435-
id_name = ManagedIdentity._types_mapping.get(
435+
id_name = (types_mapping or ManagedIdentity._types_mapping).get(
436436
managed_identity.get(ManagedIdentity.ID_TYPE))
437437
if id_name:
438438
params[id_name] = managed_identity[ManagedIdentity.ID]
@@ -479,7 +479,12 @@ def _obtain_token_on_app_service(
479479
"api-version": "2019-08-01",
480480
"resource": resource,
481481
}
482-
_adjust_param(params, managed_identity)
482+
_adjust_param(params, managed_identity, types_mapping={
483+
ManagedIdentity.CLIENT_ID: "client_id",
484+
ManagedIdentity.RESOURCE_ID: "mi_res_id", # App Service's resource id uses "mi_res_id"
485+
ManagedIdentity.OBJECT_ID: "object_id",
486+
})
487+
483488
resp = http_client.get(
484489
endpoint,
485490
params=params,

tests/test_mi.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,22 @@ def test_vm_error_should_be_returned_as_is(self):
139139
json.loads(raw_error), self.app.acquire_token_for_client(resource="R"))
140140
self.assertEqual({}, self.app._token_cache._cache)
141141

142+
def test_vm_resource_id_parameter_should_be_msi_res_id(self):
143+
app = ManagedIdentityClient(
144+
{"ManagedIdentityIdType": "ResourceId", "Id": "1234"},
145+
http_client=requests.Session(),
146+
)
147+
with patch.object(app._http_client, "get", return_value=MinimalResponse(
148+
status_code=200,
149+
text='{"access_token": "AT", "expires_in": 3600, "resource": "R"}',
150+
)) as mocked_method:
151+
app.acquire_token_for_client(resource="R")
152+
mocked_method.assert_called_with(
153+
'http://169.254.169.254/metadata/identity/oauth2/token',
154+
params={'api-version': '2018-02-01', 'resource': 'R', 'msi_res_id': '1234'},
155+
headers={'Metadata': 'true'},
156+
)
157+
142158

143159
@patch.dict(os.environ, {"IDENTITY_ENDPOINT": "http://localhost", "IDENTITY_HEADER": "foo"})
144160
class AppServiceTestCase(ClientTestCase):
@@ -164,6 +180,22 @@ def test_app_service_error_should_be_normalized(self):
164180
}, self.app.acquire_token_for_client(resource="R"))
165181
self.assertEqual({}, self.app._token_cache._cache)
166182

183+
def test_app_service_resource_id_parameter_should_be_mi_res_id(self):
184+
app = ManagedIdentityClient(
185+
{"ManagedIdentityIdType": "ResourceId", "Id": "1234"},
186+
http_client=requests.Session(),
187+
)
188+
with patch.object(app._http_client, "get", return_value=MinimalResponse(
189+
status_code=200,
190+
text='{"access_token": "AT", "expires_on": 12345, "resource": "R"}',
191+
)) as mocked_method:
192+
app.acquire_token_for_client(resource="R")
193+
mocked_method.assert_called_with(
194+
'http://localhost',
195+
params={'api-version': '2019-08-01', 'resource': 'R', 'mi_res_id': '1234'},
196+
headers={'X-IDENTITY-HEADER': 'foo', 'Metadata': 'true'},
197+
)
198+
167199

168200
@patch.dict(os.environ, {"MSI_ENDPOINT": "http://localhost", "MSI_SECRET": "foo"})
169201
class MachineLearningTestCase(ClientTestCase):

0 commit comments

Comments
 (0)