Skip to content
This repository was archived by the owner on Apr 10, 2024. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 54 additions & 2 deletions msrestazure/azure_active_directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@
import re
import time
try:
from urlparse import urlparse, parse_qs
from urlparse import urljoin, urlparse, parse_qs
except ImportError:
from urllib.parse import urlparse, parse_qs
from urllib.parse import urljoin, urlparse, parse_qs

import adal
import keyring
from oauthlib.oauth2 import BackendApplicationClient, LegacyApplicationClient
from oauthlib.oauth2.rfc6749.errors import (
Expand All @@ -42,6 +43,7 @@
from requests import RequestException
import requests_oauthlib as oauth

import msrest.authentication
from msrest.authentication import OAuthTokenAuthentication
from msrest.exceptions import TokenExpiredError as Expired
from msrest.exceptions import (
Expand Down Expand Up @@ -525,3 +527,53 @@ def set_token(self, response_url):
raise_with_traceback(AuthenticationError, "", err)
else:
self.token = token


# Constants related to AAD-based authentication methods.
_TOKEN_ENTRY_TOKEN_TYPE = 'tokenType'
_ACCESS_TOKEN = 'accessToken'
XPLAT_APP_ID = "04b07795-8ddb-461a-bbee-02f9e1bf7b46"


class AdalAuthentication(msrest.authentication.Authentication):

"""Base class for adal-derived authentication."""

def __init__(self, client_id,
tenant="common",
auth_endpoint="https://login.microsoftonline.com",
resource="https://management.core.windows.net/"):
"""Handle details common to adal."""
super(AdalAuthentication, self).__init__()
self.client_id = client_id
self.authority = urljoin(auth_endpoint, tenant)
self.resource = resource
context = adal.AuthenticationContext(self.authority)
self.token = self.acquire_token(context)

# @abc.abstractmethod if Python 2 wasn't supported.
def acquire_token(self, context):
"""Override with code that returns an adal.acquire_*() call."""
raise NotImplementedError

def signed_session(self):
"""Return a signed session."""
session = super(AdalAuthentication, self).signed_session()
header = " ".join([self.token[_TOKEN_ENTRY_TOKEN_TYPE],
self.token[_ACCESS_TOKEN]])
session.headers['Authorization'] = header
return session


class AdalUserPassCredentials(AdalAuthentication):

"""Authenticate with AAD using a username and password."""

def __init__(self, username, password, client_id, **kwargs):
self.username = username
self.password = password
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This implies that credentials are not tested at the instance creation, but at the instance usage. This is different from current UserPassCredentials behavior.
I think we might want to keep the current behavior, but I'm open to discussion (@annatisch)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

super(AdalUserPassCredentials, self).__init__(client_id, **kwargs)

def acquire_token(self, context):
return context.acquire_token_with_username_password(
self.resource, self.username, self.password, self.client_id)
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,5 +50,6 @@
'Topic :: Software Development'],
install_requires=[
"msrest~=0.4.4",
"adal>=0.4.0",
"keyring>=5.6"],
)
81 changes: 72 additions & 9 deletions test/unittest_auth.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#--------------------------------------------------------------------------
#
# Copyright (c) Microsoft Corporation. All rights reserved.
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# The MIT License (MIT)
#
Expand Down Expand Up @@ -31,6 +31,10 @@
from unittest import mock
except ImportError:
import mock
try:
from urlparse import urljoin
except ImportError:
from urllib.parse import urljoin

from requests_oauthlib import OAuth2Session
import oauthlib
Expand All @@ -54,7 +58,7 @@ class TestInteractiveCredentials(unittest.TestCase):
def setUp(self):
self.cfg = AzureConfiguration("https://my_service.com")
return super(TestInteractiveCredentials, self).setUp()

def test_http(self):

test_uri = "http://my_service.com"
Expand Down Expand Up @@ -282,8 +286,8 @@ def test_service_principal(self):
session = mock.create_autospec(OAuth2Session)
with mock.patch.object(
ServicePrincipalCredentials, '_setup_session', return_value=session):
creds = ServicePrincipalCredentials("client_id", "secret",

creds = ServicePrincipalCredentials("client_id", "secret",
verify=False, tenant="private")

session.fetch_token.assert_called_with(
Expand All @@ -294,7 +298,7 @@ def test_service_principal(self):

with mock.patch.object(
ServicePrincipalCredentials, '_setup_session', return_value=session):

creds = ServicePrincipalCredentials("client_id", "secret", china=True,
verify=False, tenant="private")

Expand Down Expand Up @@ -336,8 +340,8 @@ def test_user_pass_credentials(self):
session = mock.create_autospec(OAuth2Session)
with mock.patch.object(
UserPassCredentials, '_setup_session', return_value=session):
creds = UserPassCredentials("my_username", "my_password",

creds = UserPassCredentials("my_username", "my_password",
verify=False, tenant="private", resource='resource')

session.fetch_token.assert_called_with(
Expand All @@ -347,7 +351,7 @@ def test_user_pass_credentials(self):

with mock.patch.object(
UserPassCredentials, '_setup_session', return_value=session):

creds = UserPassCredentials("my_username", "my_password", client_id="client_id",
verify=False, tenant="private", china=True)

Expand All @@ -357,5 +361,64 @@ def test_user_pass_credentials(self):
password='my_password', resource='https://management.core.chinacloudapi.cn/', verify=False)


class TestAdalAuthentication(unittest.TestCase):

"""Test authentication using adal."""

def test_base_init(self):
# Test azure_active_directory.AdalAuthentication.__init__().
endpoint = "https://localhost"
tenant = "test-tenant"
token_data = object() # Sentinel.
class FakeAuth(azure_active_directory.AdalAuthentication):
def acquire_token(self, context):
self.context = context
return token_data

auth = FakeAuth(
azure_active_directory.XPLAT_APP_ID,
auth_endpoint=endpoint,
tenant=tenant)
self.assertEqual(auth.authority, urljoin(endpoint, tenant))
self.assertIs(auth.token, token_data)

@mock.patch("adal.AuthenticationContext")
def test_signed_session(self, AuthMock):
# Test azure_active_directory.AdalAuthentication.signed_session().
sentinel = object()
AuthMock.return_value = sentinel
access_token = "access-token"
token_type = "token-type"
token_data = {azure_active_directory._TOKEN_ENTRY_TOKEN_TYPE: token_type,
azure_active_directory._ACCESS_TOKEN: access_token}
class FakeAuth(azure_active_directory.AdalAuthentication):
def acquire_token(self, context):
self.context = context
return token_data

auth =FakeAuth(azure_active_directory.XPLAT_APP_ID)
token = auth.signed_session()
self.assertEqual(AuthMock.call_args, mock.call(auth.authority))
self.assertEqual(token.headers["Authorization"],
" ".join([token_type, access_token]))
self.assertIs(auth.context, sentinel)

@mock.patch("adal.AuthenticationContext")
def test_username_password(self, mocked_context):
# Test azure_active_directory.AdaluserPassCredentials.__init__() calls
# context.acquire_token_with_username_password().
username = 'msrestazure-test'
password = 'password'
context = mock.Mock()
mocked_context.return_value = context
auth = azure_active_directory.AdalUserPassCredentials(
username,
password,
azure_active_directory.XPLAT_APP_ID)
acquire_call = mock.call.acquire_token_with_username_password(
auth.resource, username, password, auth.client_id)
self.assertEqual(context.method_calls, [acquire_call])


if __name__ == '__main__':
unittest.main()
unittest.main()