diff --git a/msal/managed_identity.py b/msal/managed_identity.py index 181d34c3..608bc1bf 100644 --- a/msal/managed_identity.py +++ b/msal/managed_identity.py @@ -40,14 +40,15 @@ class ManagedIdentity(UserDict): _types_mapping = { # Maps type name in configuration to type name on wire CLIENT_ID: "client_id", - RESOURCE_ID: "mi_res_id", + RESOURCE_ID: "msi_res_id", # VM's IMDS prefers msi_res_id https://github.com/Azure/azure-rest-api-specs/blob/dba6ed1f03bda88ac6884c0a883246446cc72495/specification/imds/data-plane/Microsoft.InstanceMetadataService/stable/2018-10-01/imds.json#L233-L239 OBJECT_ID: "object_id", } @classmethod def is_managed_identity(cls, unknown): - return isinstance(unknown, ManagedIdentity) or ( - isinstance(unknown, dict) and cls.ID_TYPE in unknown) + return (isinstance(unknown, ManagedIdentity) + or cls.is_system_assigned(unknown) + or cls.is_user_assigned(unknown)) @classmethod def is_system_assigned(cls, unknown): @@ -217,6 +218,9 @@ def __init__( ) token = client.acquire_token_for_client("resource") """ + if not ManagedIdentity.is_managed_identity(managed_identity): + raise ManagedIdentityError( + f"Incorrect managed_identity: {managed_identity}") self._managed_identity = managed_identity self._http_client = _ThrottledHttpClient( # This class only throttles excess token acquisition requests. @@ -426,9 +430,9 @@ def _obtain_token(http_client, managed_identity, resource): return _obtain_token_on_azure_vm(http_client, managed_identity, resource) -def _adjust_param(params, managed_identity): +def _adjust_param(params, managed_identity, types_mapping=None): # Modify the params dict in place - id_name = ManagedIdentity._types_mapping.get( + id_name = (types_mapping or ManagedIdentity._types_mapping).get( managed_identity.get(ManagedIdentity.ID_TYPE)) if id_name: params[id_name] = managed_identity[ManagedIdentity.ID] @@ -475,7 +479,12 @@ def _obtain_token_on_app_service( "api-version": "2019-08-01", "resource": resource, } - _adjust_param(params, managed_identity) + _adjust_param(params, managed_identity, types_mapping={ + ManagedIdentity.CLIENT_ID: "client_id", + ManagedIdentity.RESOURCE_ID: "mi_res_id", # App Service's resource id uses "mi_res_id" + ManagedIdentity.OBJECT_ID: "object_id", + }) + resp = http_client.get( endpoint, params=params, diff --git a/tests/test_mi.py b/tests/test_mi.py index 2041419d..1f33fe73 100644 --- a/tests/test_mi.py +++ b/tests/test_mi.py @@ -61,6 +61,14 @@ def setUp(self): http_client=requests.Session(), ) + def test_error_out_on_invalid_input(self): + with self.assertRaises(ManagedIdentityError): + ManagedIdentityClient({"foo": "bar"}, http_client=requests.Session()) + with self.assertRaises(ManagedIdentityError): + ManagedIdentityClient( + {"ManagedIdentityIdType": "undefined", "Id": "foo"}, + http_client=requests.Session()) + def assertCacheStatus(self, app): cache = app._token_cache._cache self.assertEqual(1, len(cache.get("AccessToken", [])), "Should have 1 AT") @@ -131,6 +139,22 @@ def test_vm_error_should_be_returned_as_is(self): json.loads(raw_error), self.app.acquire_token_for_client(resource="R")) self.assertEqual({}, self.app._token_cache._cache) + def test_vm_resource_id_parameter_should_be_msi_res_id(self): + app = ManagedIdentityClient( + {"ManagedIdentityIdType": "ResourceId", "Id": "1234"}, + http_client=requests.Session(), + ) + with patch.object(app._http_client, "get", return_value=MinimalResponse( + status_code=200, + text='{"access_token": "AT", "expires_in": 3600, "resource": "R"}', + )) as mocked_method: + app.acquire_token_for_client(resource="R") + mocked_method.assert_called_with( + 'http://169.254.169.254/metadata/identity/oauth2/token', + params={'api-version': '2018-02-01', 'resource': 'R', 'msi_res_id': '1234'}, + headers={'Metadata': 'true'}, + ) + @patch.dict(os.environ, {"IDENTITY_ENDPOINT": "http://localhost", "IDENTITY_HEADER": "foo"}) class AppServiceTestCase(ClientTestCase): @@ -156,6 +180,22 @@ def test_app_service_error_should_be_normalized(self): }, self.app.acquire_token_for_client(resource="R")) self.assertEqual({}, self.app._token_cache._cache) + def test_app_service_resource_id_parameter_should_be_mi_res_id(self): + app = ManagedIdentityClient( + {"ManagedIdentityIdType": "ResourceId", "Id": "1234"}, + http_client=requests.Session(), + ) + with patch.object(app._http_client, "get", return_value=MinimalResponse( + status_code=200, + text='{"access_token": "AT", "expires_on": 12345, "resource": "R"}', + )) as mocked_method: + app.acquire_token_for_client(resource="R") + mocked_method.assert_called_with( + 'http://localhost', + params={'api-version': '2019-08-01', 'resource': 'R', 'mi_res_id': '1234'}, + headers={'X-IDENTITY-HEADER': 'foo', 'Metadata': 'true'}, + ) + @patch.dict(os.environ, {"MSI_ENDPOINT": "http://localhost", "MSI_SECRET": "foo"}) class MachineLearningTestCase(ClientTestCase): @@ -241,6 +281,9 @@ class ArcTestCase(ClientTestCase): "WWW-Authenticate": "Basic realm=/tmp/foo", }) + def test_error_out_on_invalid_input(self, mocked_stat): + return super(ArcTestCase, self).test_error_out_on_invalid_input() + def test_happy_path(self, mocked_stat): expires_in = 1234 with patch.object(self.app._http_client, "get", side_effect=[