Skip to content
This repository was archived by the owner on Apr 10, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 82 additions & 5 deletions msrestazure/azure_active_directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,20 +33,19 @@
from urllib.parse import urlparse, parse_qs

import keyring
import adal
from oauthlib.oauth2 import BackendApplicationClient, LegacyApplicationClient
from oauthlib.oauth2.rfc6749.errors import (
InvalidGrantError,
MismatchingStateError,
OAuth2Error,
TokenExpiredError)
from requests import RequestException
from requests import RequestException, ConnectionError
import requests_oauthlib as oauth

from msrest.authentication import OAuthTokenAuthentication
from msrest.authentication import OAuthTokenAuthentication, Authentication
from msrest.exceptions import TokenExpiredError as Expired
from msrest.exceptions import (
AuthenticationError,
raise_with_traceback)
from msrest.exceptions import AuthenticationError, raise_with_traceback


def _build_url(uri, paths, scheme):
Expand Down Expand Up @@ -525,3 +524,81 @@ def set_token(self, response_url):
raise_with_traceback(AuthenticationError, "", err)
else:
self.token = token


class AdalAuthentication(Authentication): # pylint: disable=too-few-public-methods
"""A wrapper to use ADAL for Python easily to authenticate on Azure."""

def __init__(self, adal_method, *args, **kwargs):
"""Take an ADAL `acquire_token` method and its parameters.

For example, this code from the ADAL tutorial:

```python
context = adal.AuthenticationContext('https://login.microsoftonline.com/ABCDEFGH-1234-1234-1234-ABCDEFGHIJKL')
RESOURCE = '00000002-0000-0000-c000-000000000000' #AAD graph resource
token = context.acquire_token_with_client_credentials(
RESOURCE,
"http://PythonSDK",
"Key-Configured-In-Portal")
```

can be written here:

```python
context = adal.AuthenticationContext('https://login.microsoftonline.com/ABCDEFGH-1234-1234-1234-ABCDEFGHIJKL')
RESOURCE = '00000002-0000-0000-c000-000000000000' #AAD graph resource
credentials = AdalAuthentication(
context.acquire_token_with_client_credentials,
RESOURCE,
"http://PythonSDK",
"Key-Configured-In-Portal")
```

or using a lambda if you prefer:

```python
context = adal.AuthenticationContext('https://login.microsoftonline.com/ABCDEFGH-1234-1234-1234-ABCDEFGHIJKL')
RESOURCE = '00000002-0000-0000-c000-000000000000' #AAD graph resource
credentials = AdalAuthentication(
lambda: context.acquire_token_with_client_credentials(
RESOURCE,
"http://PythonSDK",
"Key-Configured-In-Portal"
)
)
```

:param adal_method: A lambda with no args, or `acquire_token` method with args using args/kwargs
:param args: Optional args for the method
:param kwargs: Optional kwargs for the method
"""
self._adal_method = adal_method
self._args = args
self._kwargs = kwargs

def signed_session(self):
"""Get a signed session for requests.

Usually called by the Azure SDKs for you to authenticate queries.

:rtype: requests.Session
"""
session = super(AdalAuthentication, self).signed_session()

try:
raw_token = self._adal_method(*self._args, **self._kwargs)
except adal.AdalError as err:
# pylint: disable=no-member
if (hasattr(err, 'error_response') and ('error_description' in err.error_response)
and ('AADSTS70008:' in err.error_response['error_description'])):
raise Expired("Credentials have expired due to inactivity.")
else:
raise AuthenticationError(err)
except ConnectionError as err:
raise AuthenticationError('Please ensure you have network connection. Error detail: ' + str(err))

scheme, token = raw_token['tokenType'], raw_token['accessToken']
header = "{} {}".format(scheme, token)
session.headers['Authorization'] = header
return session
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,5 +50,7 @@
'Topic :: Software Development'],
install_requires=[
"msrest~=0.4.4",
"keyring>=5.6"],
"keyring>=5.6",
"adal~=0.4.0"
],
)
42 changes: 35 additions & 7 deletions test/unittest_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,19 @@

from requests_oauthlib import OAuth2Session
import oauthlib
import adal

from msrestazure import AzureConfiguration
from msrestazure import azure_active_directory
from msrestazure.azure_active_directory import (
AADMixin,
InteractiveCredentials,
ServicePrincipalCredentials,
UserPassCredentials
)
from msrest.exceptions import (
TokenExpiredError,
AuthenticationError,
)
UserPassCredentials,
AdalAuthentication
)
from msrest.exceptions import TokenExpiredError, AuthenticationError
from requests import ConnectionError


class TestInteractiveCredentials(unittest.TestCase):
Expand Down Expand Up @@ -356,6 +356,34 @@ def test_user_pass_credentials(self):
client_id="client_id", username='my_username',
password='my_password', resource='https://management.core.chinacloudapi.cn/', verify=False)

def test_adal_authentication(self):
def success_auth():
return {
'tokenType': 'https',
'accessToken': 'cryptictoken'
}

credentials = AdalAuthentication(success_auth)
session = credentials.signed_session()
self.assertEquals(session.headers['Authorization'], 'https cryptictoken')

def error():
raise adal.AdalError("You hacker", {})
credentials = AdalAuthentication(error)
with self.assertRaises(AuthenticationError) as cm:
session = credentials.signed_session()

def expired():
raise adal.AdalError("Too late", {'error_description': "AADSTS70008: Expired"})
credentials = AdalAuthentication(expired)
with self.assertRaises(TokenExpiredError) as cm:
session = credentials.signed_session()

def connection_error():
raise ConnectionError("Plug the network")
credentials = AdalAuthentication(connection_error)
with self.assertRaises(AuthenticationError) as cm:
session = credentials.signed_session()

if __name__ == '__main__':
unittest.main()
unittest.main()