Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions msal/managed_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,15 @@ class ManagedIdentity(UserDict):

_types_mapping = { # Maps type name in configuration to type name on wire
CLIENT_ID: "client_id",
RESOURCE_ID: "mi_res_id",
RESOURCE_ID: "msi_res_id", # VM's IMDS prefers msi_res_id https://github.com/Azure/azure-rest-api-specs/blob/dba6ed1f03bda88ac6884c0a883246446cc72495/specification/imds/data-plane/Microsoft.InstanceMetadataService/stable/2018-10-01/imds.json#L233-L239
OBJECT_ID: "object_id",
}

@classmethod
def is_managed_identity(cls, unknown):
return isinstance(unknown, ManagedIdentity) or (
isinstance(unknown, dict) and cls.ID_TYPE in unknown)
return (isinstance(unknown, ManagedIdentity)
or cls.is_system_assigned(unknown)
or cls.is_user_assigned(unknown))

@classmethod
def is_system_assigned(cls, unknown):
Expand Down Expand Up @@ -217,6 +218,9 @@ def __init__(
)
token = client.acquire_token_for_client("resource")
"""
if not ManagedIdentity.is_managed_identity(managed_identity):
raise ManagedIdentityError(
f"Incorrect managed_identity: {managed_identity}")
self._managed_identity = managed_identity
self._http_client = _ThrottledHttpClient(
# This class only throttles excess token acquisition requests.
Expand Down Expand Up @@ -426,9 +430,9 @@ def _obtain_token(http_client, managed_identity, resource):
return _obtain_token_on_azure_vm(http_client, managed_identity, resource)


def _adjust_param(params, managed_identity):
def _adjust_param(params, managed_identity, types_mapping=None):
# Modify the params dict in place
id_name = ManagedIdentity._types_mapping.get(
id_name = (types_mapping or ManagedIdentity._types_mapping).get(
managed_identity.get(ManagedIdentity.ID_TYPE))
if id_name:
params[id_name] = managed_identity[ManagedIdentity.ID]
Expand Down Expand Up @@ -475,7 +479,12 @@ def _obtain_token_on_app_service(
"api-version": "2019-08-01",
"resource": resource,
}
_adjust_param(params, managed_identity)
_adjust_param(params, managed_identity, types_mapping={
ManagedIdentity.CLIENT_ID: "client_id",
ManagedIdentity.RESOURCE_ID: "mi_res_id", # App Service's resource id uses "mi_res_id"
ManagedIdentity.OBJECT_ID: "object_id",
})

resp = http_client.get(
endpoint,
params=params,
Expand Down
43 changes: 43 additions & 0 deletions tests/test_mi.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,14 @@ def setUp(self):
http_client=requests.Session(),
)

def test_error_out_on_invalid_input(self):
with self.assertRaises(ManagedIdentityError):
ManagedIdentityClient({"foo": "bar"}, http_client=requests.Session())
with self.assertRaises(ManagedIdentityError):
ManagedIdentityClient(
{"ManagedIdentityIdType": "undefined", "Id": "foo"},
http_client=requests.Session())

def assertCacheStatus(self, app):
cache = app._token_cache._cache
self.assertEqual(1, len(cache.get("AccessToken", [])), "Should have 1 AT")
Expand Down Expand Up @@ -131,6 +139,22 @@ def test_vm_error_should_be_returned_as_is(self):
json.loads(raw_error), self.app.acquire_token_for_client(resource="R"))
self.assertEqual({}, self.app._token_cache._cache)

def test_vm_resource_id_parameter_should_be_msi_res_id(self):
app = ManagedIdentityClient(
{"ManagedIdentityIdType": "ResourceId", "Id": "1234"},
http_client=requests.Session(),
)
with patch.object(app._http_client, "get", return_value=MinimalResponse(
status_code=200,
text='{"access_token": "AT", "expires_in": 3600, "resource": "R"}',
)) as mocked_method:
app.acquire_token_for_client(resource="R")
mocked_method.assert_called_with(
'http://169.254.169.254/metadata/identity/oauth2/token',
params={'api-version': '2018-02-01', 'resource': 'R', 'msi_res_id': '1234'},
headers={'Metadata': 'true'},
)


@patch.dict(os.environ, {"IDENTITY_ENDPOINT": "http://localhost", "IDENTITY_HEADER": "foo"})
class AppServiceTestCase(ClientTestCase):
Expand All @@ -156,6 +180,22 @@ def test_app_service_error_should_be_normalized(self):
}, self.app.acquire_token_for_client(resource="R"))
self.assertEqual({}, self.app._token_cache._cache)

def test_app_service_resource_id_parameter_should_be_mi_res_id(self):
app = ManagedIdentityClient(
{"ManagedIdentityIdType": "ResourceId", "Id": "1234"},
http_client=requests.Session(),
)
with patch.object(app._http_client, "get", return_value=MinimalResponse(
status_code=200,
text='{"access_token": "AT", "expires_on": 12345, "resource": "R"}',
)) as mocked_method:
app.acquire_token_for_client(resource="R")
mocked_method.assert_called_with(
'http://localhost',
params={'api-version': '2019-08-01', 'resource': 'R', 'mi_res_id': '1234'},
headers={'X-IDENTITY-HEADER': 'foo', 'Metadata': 'true'},
)


@patch.dict(os.environ, {"MSI_ENDPOINT": "http://localhost", "MSI_SECRET": "foo"})
class MachineLearningTestCase(ClientTestCase):
Expand Down Expand Up @@ -241,6 +281,9 @@ class ArcTestCase(ClientTestCase):
"WWW-Authenticate": "Basic realm=/tmp/foo",
})

def test_error_out_on_invalid_input(self, mocked_stat):
return super(ArcTestCase, self).test_error_out_on_invalid_input()

def test_happy_path(self, mocked_stat):
expires_in = 1234
with patch.object(self.app._http_client, "get", side_effect=[
Expand Down