diff --git a/msrestazure/azure_active_directory.py b/msrestazure/azure_active_directory.py index cda6808..c7e53ed 100644 --- a/msrestazure/azure_active_directory.py +++ b/msrestazure/azure_active_directory.py @@ -28,10 +28,11 @@ import re import time try: - from urlparse import urlparse, parse_qs + from urlparse import urljoin, urlparse, parse_qs except ImportError: - from urllib.parse import urlparse, parse_qs + from urllib.parse import urljoin, urlparse, parse_qs +import adal import keyring from oauthlib.oauth2 import BackendApplicationClient, LegacyApplicationClient from oauthlib.oauth2.rfc6749.errors import ( @@ -42,6 +43,7 @@ from requests import RequestException import requests_oauthlib as oauth +import msrest.authentication from msrest.authentication import OAuthTokenAuthentication from msrest.exceptions import TokenExpiredError as Expired from msrest.exceptions import ( @@ -525,3 +527,53 @@ def set_token(self, response_url): raise_with_traceback(AuthenticationError, "", err) else: self.token = token + + +# Constants related to AAD-based authentication methods. +_TOKEN_ENTRY_TOKEN_TYPE = 'tokenType' +_ACCESS_TOKEN = 'accessToken' +XPLAT_APP_ID = "04b07795-8ddb-461a-bbee-02f9e1bf7b46" + + +class AdalAuthentication(msrest.authentication.Authentication): + + """Base class for adal-derived authentication.""" + + def __init__(self, client_id, + tenant="common", + auth_endpoint="https://login.microsoftonline.com", + resource="https://management.core.windows.net/"): + """Handle details common to adal.""" + super(AdalAuthentication, self).__init__() + self.client_id = client_id + self.authority = urljoin(auth_endpoint, tenant) + self.resource = resource + context = adal.AuthenticationContext(self.authority) + self.token = self.acquire_token(context) + + # @abc.abstractmethod if Python 2 wasn't supported. + def acquire_token(self, context): + """Override with code that returns an adal.acquire_*() call.""" + raise NotImplementedError + + def signed_session(self): + """Return a signed session.""" + session = super(AdalAuthentication, self).signed_session() + header = " ".join([self.token[_TOKEN_ENTRY_TOKEN_TYPE], + self.token[_ACCESS_TOKEN]]) + session.headers['Authorization'] = header + return session + + +class AdalUserPassCredentials(AdalAuthentication): + + """Authenticate with AAD using a username and password.""" + + def __init__(self, username, password, client_id, **kwargs): + self.username = username + self.password = password + super(AdalUserPassCredentials, self).__init__(client_id, **kwargs) + + def acquire_token(self, context): + return context.acquire_token_with_username_password( + self.resource, self.username, self.password, self.client_id) diff --git a/setup.py b/setup.py index 2c8587f..c30f1d1 100644 --- a/setup.py +++ b/setup.py @@ -50,5 +50,6 @@ 'Topic :: Software Development'], install_requires=[ "msrest~=0.4.4", + "adal>=0.4.0", "keyring>=5.6"], ) diff --git a/test/unittest_auth.py b/test/unittest_auth.py index 363b4b8..5b4ec8a 100644 --- a/test/unittest_auth.py +++ b/test/unittest_auth.py @@ -1,6 +1,6 @@ #-------------------------------------------------------------------------- # -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. All rights reserved. # # The MIT License (MIT) # @@ -31,6 +31,10 @@ from unittest import mock except ImportError: import mock +try: + from urlparse import urljoin +except ImportError: + from urllib.parse import urljoin from requests_oauthlib import OAuth2Session import oauthlib @@ -54,7 +58,7 @@ class TestInteractiveCredentials(unittest.TestCase): def setUp(self): self.cfg = AzureConfiguration("https://my_service.com") return super(TestInteractiveCredentials, self).setUp() - + def test_http(self): test_uri = "http://my_service.com" @@ -282,8 +286,8 @@ def test_service_principal(self): session = mock.create_autospec(OAuth2Session) with mock.patch.object( ServicePrincipalCredentials, '_setup_session', return_value=session): - - creds = ServicePrincipalCredentials("client_id", "secret", + + creds = ServicePrincipalCredentials("client_id", "secret", verify=False, tenant="private") session.fetch_token.assert_called_with( @@ -294,7 +298,7 @@ def test_service_principal(self): with mock.patch.object( ServicePrincipalCredentials, '_setup_session', return_value=session): - + creds = ServicePrincipalCredentials("client_id", "secret", china=True, verify=False, tenant="private") @@ -336,8 +340,8 @@ def test_user_pass_credentials(self): session = mock.create_autospec(OAuth2Session) with mock.patch.object( UserPassCredentials, '_setup_session', return_value=session): - - creds = UserPassCredentials("my_username", "my_password", + + creds = UserPassCredentials("my_username", "my_password", verify=False, tenant="private", resource='resource') session.fetch_token.assert_called_with( @@ -347,7 +351,7 @@ def test_user_pass_credentials(self): with mock.patch.object( UserPassCredentials, '_setup_session', return_value=session): - + creds = UserPassCredentials("my_username", "my_password", client_id="client_id", verify=False, tenant="private", china=True) @@ -357,5 +361,64 @@ def test_user_pass_credentials(self): password='my_password', resource='https://management.core.chinacloudapi.cn/', verify=False) +class TestAdalAuthentication(unittest.TestCase): + + """Test authentication using adal.""" + + def test_base_init(self): + # Test azure_active_directory.AdalAuthentication.__init__(). + endpoint = "https://localhost" + tenant = "test-tenant" + token_data = object() # Sentinel. + class FakeAuth(azure_active_directory.AdalAuthentication): + def acquire_token(self, context): + self.context = context + return token_data + + auth = FakeAuth( + azure_active_directory.XPLAT_APP_ID, + auth_endpoint=endpoint, + tenant=tenant) + self.assertEqual(auth.authority, urljoin(endpoint, tenant)) + self.assertIs(auth.token, token_data) + + @mock.patch("adal.AuthenticationContext") + def test_signed_session(self, AuthMock): + # Test azure_active_directory.AdalAuthentication.signed_session(). + sentinel = object() + AuthMock.return_value = sentinel + access_token = "access-token" + token_type = "token-type" + token_data = {azure_active_directory._TOKEN_ENTRY_TOKEN_TYPE: token_type, + azure_active_directory._ACCESS_TOKEN: access_token} + class FakeAuth(azure_active_directory.AdalAuthentication): + def acquire_token(self, context): + self.context = context + return token_data + + auth =FakeAuth(azure_active_directory.XPLAT_APP_ID) + token = auth.signed_session() + self.assertEqual(AuthMock.call_args, mock.call(auth.authority)) + self.assertEqual(token.headers["Authorization"], + " ".join([token_type, access_token])) + self.assertIs(auth.context, sentinel) + + @mock.patch("adal.AuthenticationContext") + def test_username_password(self, mocked_context): + # Test azure_active_directory.AdaluserPassCredentials.__init__() calls + # context.acquire_token_with_username_password(). + username = 'msrestazure-test' + password = 'password' + context = mock.Mock() + mocked_context.return_value = context + auth = azure_active_directory.AdalUserPassCredentials( + username, + password, + azure_active_directory.XPLAT_APP_ID) + acquire_call = mock.call.acquire_token_with_username_password( + auth.resource, username, password, auth.client_id) + self.assertEqual(context.method_calls, [acquire_call]) + + if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main()