Skip to content

Commit 0a756e9

Browse files
committed
Error out on invalid ManagedIdentity dict
1 parent 85c93f8 commit 0a756e9

File tree

2 files changed

+17
-2
lines changed

2 files changed

+17
-2
lines changed

msal/managed_identity.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,9 @@ class ManagedIdentity(UserDict):
4646

4747
@classmethod
4848
def is_managed_identity(cls, unknown):
49-
return isinstance(unknown, ManagedIdentity) or (
50-
isinstance(unknown, dict) and cls.ID_TYPE in unknown)
49+
return (isinstance(unknown, ManagedIdentity)
50+
or cls.is_system_assigned(unknown)
51+
or cls.is_user_assigned(unknown))
5152

5253
@classmethod
5354
def is_system_assigned(cls, unknown):
@@ -217,6 +218,9 @@ def __init__(
217218
)
218219
token = client.acquire_token_for_client("resource")
219220
"""
221+
if not ManagedIdentity.is_managed_identity(managed_identity):
222+
raise ManagedIdentityError(
223+
f"Incorrect managed_identity: {managed_identity}")
220224
self._managed_identity = managed_identity
221225
self._http_client = _ThrottledHttpClient(
222226
# This class only throttles excess token acquisition requests.

tests/test_mi.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,14 @@ def setUp(self):
6161
http_client=requests.Session(),
6262
)
6363

64+
def test_error_out_on_invalid_input(self):
65+
with self.assertRaises(ManagedIdentityError):
66+
ManagedIdentityClient({"foo": "bar"}, http_client=requests.Session())
67+
with self.assertRaises(ManagedIdentityError):
68+
ManagedIdentityClient(
69+
{"ManagedIdentityIdType": "undefined", "Id": "foo"},
70+
http_client=requests.Session())
71+
6472
def assertCacheStatus(self, app):
6573
cache = app._token_cache._cache
6674
self.assertEqual(1, len(cache.get("AccessToken", [])), "Should have 1 AT")
@@ -241,6 +249,9 @@ class ArcTestCase(ClientTestCase):
241249
"WWW-Authenticate": "Basic realm=/tmp/foo",
242250
})
243251

252+
def test_error_out_on_invalid_input(self, mocked_stat):
253+
return super(ArcTestCase, self).test_error_out_on_invalid_input()
254+
244255
def test_happy_path(self, mocked_stat):
245256
expires_in = 1234
246257
with patch.object(self.app._http_client, "get", side_effect=[

0 commit comments

Comments
 (0)