diff --git a/msrestazure/azure_active_directory.py b/msrestazure/azure_active_directory.py index cda6808..c8e52a3 100644 --- a/msrestazure/azure_active_directory.py +++ b/msrestazure/azure_active_directory.py @@ -33,20 +33,19 @@ from urllib.parse import urlparse, parse_qs import keyring +import adal from oauthlib.oauth2 import BackendApplicationClient, LegacyApplicationClient from oauthlib.oauth2.rfc6749.errors import ( InvalidGrantError, MismatchingStateError, OAuth2Error, TokenExpiredError) -from requests import RequestException +from requests import RequestException, ConnectionError import requests_oauthlib as oauth -from msrest.authentication import OAuthTokenAuthentication +from msrest.authentication import OAuthTokenAuthentication, Authentication from msrest.exceptions import TokenExpiredError as Expired -from msrest.exceptions import ( - AuthenticationError, - raise_with_traceback) +from msrest.exceptions import AuthenticationError, raise_with_traceback def _build_url(uri, paths, scheme): @@ -525,3 +524,81 @@ def set_token(self, response_url): raise_with_traceback(AuthenticationError, "", err) else: self.token = token + + +class AdalAuthentication(Authentication): # pylint: disable=too-few-public-methods + """A wrapper to use ADAL for Python easily to authenticate on Azure.""" + + def __init__(self, adal_method, *args, **kwargs): + """Take an ADAL `acquire_token` method and its parameters. + + For example, this code from the ADAL tutorial: + + ```python + context = adal.AuthenticationContext('https://login.microsoftonline.com/ABCDEFGH-1234-1234-1234-ABCDEFGHIJKL') + RESOURCE = '00000002-0000-0000-c000-000000000000' #AAD graph resource + token = context.acquire_token_with_client_credentials( + RESOURCE, + "http://PythonSDK", + "Key-Configured-In-Portal") + ``` + + can be written here: + + ```python + context = adal.AuthenticationContext('https://login.microsoftonline.com/ABCDEFGH-1234-1234-1234-ABCDEFGHIJKL') + RESOURCE = '00000002-0000-0000-c000-000000000000' #AAD graph resource + credentials = AdalAuthentication( + context.acquire_token_with_client_credentials, + RESOURCE, + "http://PythonSDK", + "Key-Configured-In-Portal") + ``` + + or using a lambda if you prefer: + + ```python + context = adal.AuthenticationContext('https://login.microsoftonline.com/ABCDEFGH-1234-1234-1234-ABCDEFGHIJKL') + RESOURCE = '00000002-0000-0000-c000-000000000000' #AAD graph resource + credentials = AdalAuthentication( + lambda: context.acquire_token_with_client_credentials( + RESOURCE, + "http://PythonSDK", + "Key-Configured-In-Portal" + ) + ) + ``` + + :param adal_method: A lambda with no args, or `acquire_token` method with args using args/kwargs + :param args: Optional args for the method + :param kwargs: Optional kwargs for the method + """ + self._adal_method = adal_method + self._args = args + self._kwargs = kwargs + + def signed_session(self): + """Get a signed session for requests. + + Usually called by the Azure SDKs for you to authenticate queries. + + :rtype: requests.Session + """ + session = super(AdalAuthentication, self).signed_session() + + try: + raw_token = self._adal_method(*self._args, **self._kwargs) + except adal.AdalError as err: + # pylint: disable=no-member + if (hasattr(err, 'error_response') and ('error_description' in err.error_response) + and ('AADSTS70008:' in err.error_response['error_description'])): + raise Expired("Credentials have expired due to inactivity.") + else: + raise AuthenticationError(err) + except ConnectionError as err: + raise AuthenticationError('Please ensure you have network connection. Error detail: ' + str(err)) + + scheme, token = raw_token['tokenType'], raw_token['accessToken'] + header = "{} {}".format(scheme, token) + session.headers['Authorization'] = header + return session diff --git a/setup.py b/setup.py index 2c8587f..2164f38 100644 --- a/setup.py +++ b/setup.py @@ -50,5 +50,7 @@ 'Topic :: Software Development'], install_requires=[ "msrest~=0.4.4", - "keyring>=5.6"], + "keyring>=5.6", + "adal~=0.4.0" + ], ) diff --git a/test/unittest_auth.py b/test/unittest_auth.py index 363b4b8..871c882 100644 --- a/test/unittest_auth.py +++ b/test/unittest_auth.py @@ -34,6 +34,7 @@ from requests_oauthlib import OAuth2Session import oauthlib +import adal from msrestazure import AzureConfiguration from msrestazure import azure_active_directory @@ -41,12 +42,11 @@ AADMixin, InteractiveCredentials, ServicePrincipalCredentials, - UserPassCredentials - ) -from msrest.exceptions import ( - TokenExpiredError, - AuthenticationError, - ) + UserPassCredentials, + AdalAuthentication +) +from msrest.exceptions import TokenExpiredError, AuthenticationError +from requests import ConnectionError class TestInteractiveCredentials(unittest.TestCase): @@ -356,6 +356,34 @@ def test_user_pass_credentials(self): client_id="client_id", username='my_username', password='my_password', resource='https://management.core.chinacloudapi.cn/', verify=False) + def test_adal_authentication(self): + def success_auth(): + return { + 'tokenType': 'https', + 'accessToken': 'cryptictoken' + } + + credentials = AdalAuthentication(success_auth) + session = credentials.signed_session() + self.assertEquals(session.headers['Authorization'], 'https cryptictoken') + + def error(): + raise adal.AdalError("You hacker", {}) + credentials = AdalAuthentication(error) + with self.assertRaises(AuthenticationError) as cm: + session = credentials.signed_session() + + def expired(): + raise adal.AdalError("Too late", {'error_description': "AADSTS70008: Expired"}) + credentials = AdalAuthentication(expired) + with self.assertRaises(TokenExpiredError) as cm: + session = credentials.signed_session() + + def connection_error(): + raise ConnectionError("Plug the network") + credentials = AdalAuthentication(connection_error) + with self.assertRaises(AuthenticationError) as cm: + session = credentials.signed_session() if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main()