diff --git a/msal_extensions/__init__.py b/msal_extensions/__init__.py index fcb4bb7..7c5f4a6 100644 --- a/msal_extensions/__init__.py +++ b/msal_extensions/__init__.py @@ -1,2 +1,11 @@ """Provides auxiliary functionality to the `msal` package.""" __version__ = "0.0.1" + +import sys + +if sys.platform.startswith('win'): + from .token_cache import WindowsTokenCache as TokenCache +elif sys.platform.startswith('darwin'): + from .token_cache import OSXTokenCache as TokenCache +else: + from .token_cache import UnencryptedTokenCache as TokenCache diff --git a/msal_extensions/osx.py b/msal_extensions/osx.py new file mode 100644 index 0000000..33f85e9 --- /dev/null +++ b/msal_extensions/osx.py @@ -0,0 +1,253 @@ +# pylint: disable=duplicate-code + +"""Implements a macOS specific TokenCache, and provides auxiliary helper types.""" + +import os +import ctypes as _ctypes + +OS_RESULT = _ctypes.c_int32 + + +class KeychainError(OSError): + """The RuntimeError that will be run when a function interacting with Keychain fails.""" + + ACCESS_DENIED = -128 + NO_SUCH_KEYCHAIN = -25294 + NO_DEFAULT = -25307 + ITEM_NOT_FOUND = -25300 + + def __init__(self, exit_status): + super(KeychainError, self).__init__() + self.exit_status = exit_status + # TODO: pylint: disable=fixme + # use SecCopyErrorMessageString to fetch the appropriate message here. + self.message = \ + '{} ' \ + 'see https://opensource.apple.com/source/CarbonHeaders/CarbonHeaders-18.1/MacErrors.h'\ + .format(self.exit_status) + +def _get_native_location(name): + # type: (str) -> str + """ + Fetches the location of a native MacOS library. + :param name: The name of the library to be loaded. + :return: The location of the library on a MacOS filesystem. + """ + return '/System/Library/Frameworks/{0}.framework/{0}'.format(name) + + +# Load native MacOS libraries +_SECURITY = _ctypes.CDLL(_get_native_location('Security')) +_CORE = _ctypes.CDLL(_get_native_location('CoreFoundation')) + + +# Bind CFRelease from native MacOS libraries. +_CORE_RELEASE = _CORE.CFRelease +_CORE_RELEASE.argtypes = ( + _ctypes.c_void_p, +) + +# Bind SecCopyErrorMessageString from native MacOS libraries. +# https://developer.apple.com/documentation/security/1394686-seccopyerrormessagestring?language=objc +_SECURITY_COPY_ERROR_MESSAGE_STRING = _SECURITY.SecCopyErrorMessageString +_SECURITY_COPY_ERROR_MESSAGE_STRING.argtypes = ( + OS_RESULT, + _ctypes.c_void_p +) +_SECURITY_COPY_ERROR_MESSAGE_STRING.restype = _ctypes.c_char_p + +# Bind SecKeychainOpen from native MacOS libraries. +# https://developer.apple.com/documentation/security/1396431-seckeychainopen +_SECURITY_KEYCHAIN_OPEN = _SECURITY.SecKeychainOpen +_SECURITY_KEYCHAIN_OPEN.argtypes = ( + _ctypes.c_char_p, + _ctypes.POINTER(_ctypes.c_void_p) +) +_SECURITY_KEYCHAIN_OPEN.restype = OS_RESULT + +# Bind SecKeychainCopyDefault from native MacOS libraries. +# https://developer.apple.com/documentation/security/1400743-seckeychaincopydefault?language=objc +_SECURITY_KEYCHAIN_COPY_DEFAULT = _SECURITY.SecKeychainCopyDefault +_SECURITY_KEYCHAIN_COPY_DEFAULT.argtypes = ( + _ctypes.POINTER(_ctypes.c_void_p), +) +_SECURITY_KEYCHAIN_COPY_DEFAULT.restype = OS_RESULT + + +# Bind SecKeychainItemFreeContent from native MacOS libraries. +_SECURITY_KEYCHAIN_ITEM_FREE_CONTENT = _SECURITY.SecKeychainItemFreeContent +_SECURITY_KEYCHAIN_ITEM_FREE_CONTENT.argtypes = ( + _ctypes.c_void_p, + _ctypes.c_void_p, +) +_SECURITY_KEYCHAIN_ITEM_FREE_CONTENT.restype = OS_RESULT + +# Bind SecKeychainItemModifyAttributesAndData from native MacOS libraries. +_SECURITY_KEYCHAIN_ITEM_MODIFY_ATTRIBUTES_AND_DATA = \ + _SECURITY.SecKeychainItemModifyAttributesAndData +_SECURITY_KEYCHAIN_ITEM_MODIFY_ATTRIBUTES_AND_DATA.argtypes = ( + _ctypes.c_void_p, + _ctypes.c_void_p, + _ctypes.c_uint32, + _ctypes.c_void_p, +) +_SECURITY_KEYCHAIN_ITEM_MODIFY_ATTRIBUTES_AND_DATA.restype = OS_RESULT + +# Bind SecKeychainFindGenericPassword from native MacOS libraries. +# https://developer.apple.com/documentation/security/1397301-seckeychainfindgenericpassword?language=objc +_SECURITY_KEYCHAIN_FIND_GENERIC_PASSWORD = _SECURITY.SecKeychainFindGenericPassword +_SECURITY_KEYCHAIN_FIND_GENERIC_PASSWORD.argtypes = ( + _ctypes.c_void_p, + _ctypes.c_uint32, + _ctypes.c_char_p, + _ctypes.c_uint32, + _ctypes.c_char_p, + _ctypes.POINTER(_ctypes.c_uint32), + _ctypes.POINTER(_ctypes.c_void_p), + _ctypes.POINTER(_ctypes.c_void_p), +) +_SECURITY_KEYCHAIN_FIND_GENERIC_PASSWORD.restype = OS_RESULT +# Bind SecKeychainAddGenericPassword from native MacOS +# https://developer.apple.com/documentation/security/1398366-seckeychainaddgenericpassword?language=objc +_SECURITY_KEYCHAIN_ADD_GENERIC_PASSWORD = _SECURITY.SecKeychainAddGenericPassword +_SECURITY_KEYCHAIN_ADD_GENERIC_PASSWORD.argtypes = ( + _ctypes.c_void_p, + _ctypes.c_uint32, + _ctypes.c_char_p, + _ctypes.c_uint32, + _ctypes.c_char_p, + _ctypes.c_uint32, + _ctypes.c_char_p, + _ctypes.POINTER(_ctypes.c_void_p), +) +_SECURITY_KEYCHAIN_ADD_GENERIC_PASSWORD.restype = OS_RESULT + + +class Keychain(object): + """Encapsulates the interactions with a particular MacOS Keychain.""" + def __init__(self, filename=None): + # type: (str) -> None + self._ref = _ctypes.c_void_p() + + if filename: + filename = os.path.expanduser(filename) + self._filename = filename.encode('utf-8') + else: + self._filename = None + + def __enter__(self): + if self._filename: + status = _SECURITY_KEYCHAIN_OPEN(self._filename, self._ref) + else: + status = _SECURITY_KEYCHAIN_COPY_DEFAULT(self._ref) + + if status: + raise OSError(status) + return self + + def __exit__(self, *args): + if self._ref: + _CORE_RELEASE(self._ref) + + def get_generic_password(self, service, account_name): + # type: (str, str) -> str + """Fetch the password associated with a particular service and account. + + :param service: The service that this password is associated with. + :param account_name: The account that this password is associated with. + :return: The value of the password associated with the specified service and account. + """ + service = service.encode('utf-8') + account_name = account_name.encode('utf-8') + + length = _ctypes.c_uint32() + contents = _ctypes.c_void_p() + exit_status = _SECURITY_KEYCHAIN_FIND_GENERIC_PASSWORD( + self._ref, + len(service), + service, + len(account_name), + account_name, + length, + contents, + None, + ) + + if exit_status: + raise KeychainError(exit_status=exit_status) + + value = _ctypes.create_string_buffer(length.value) + _ctypes.memmove(value, contents.value, length.value) + _SECURITY_KEYCHAIN_ITEM_FREE_CONTENT(None, contents) + return value.raw.decode('utf-8') + + def set_generic_password(self, service, account_name, value): + # type: (str, str, str) -> None + """Associate a password with a given service and account. + + :param service: The service to associate this password with. + :param account_name: The account to associate this password with. + :param value: The string that should be used as the password. + """ + service = service.encode('utf-8') + account_name = account_name.encode('utf-8') + value = value.encode('utf-8') + + entry = _ctypes.c_void_p() + find_exit_status = _SECURITY_KEYCHAIN_FIND_GENERIC_PASSWORD( + self._ref, + len(service), + service, + len(account_name), + account_name, + None, + None, + entry, + ) + + if not find_exit_status: + modify_exit_status = _SECURITY_KEYCHAIN_ITEM_MODIFY_ATTRIBUTES_AND_DATA( + entry, + None, + len(value), + value, + ) + if modify_exit_status: + raise KeychainError(exit_status=modify_exit_status) + + elif find_exit_status == KeychainError.ITEM_NOT_FOUND: + add_exit_status = _SECURITY_KEYCHAIN_ADD_GENERIC_PASSWORD( + self._ref, + len(service), + service, + len(account_name), + account_name, + len(value), + value, + None + ) + + if add_exit_status: + raise KeychainError(exit_status=add_exit_status) + else: + raise KeychainError(exit_status=find_exit_status) + + def get_internet_password(self, service, username): + # type: (str, str) -> str + """ Fetches a password associated with a domain and username. + NOTE: THIS IS NOT YET IMPLEMENTED + :param service: The website/service that this password is associated with. + :param username: The account that this password is associated with. + :return: The password that was associated with the given service and username. + """ + raise NotImplementedError() + + def set_internet_password(self, service, username, value): + # type: (str, str, str) -> None + """Sets a password associated with a domain and a username. + NOTE: THIS IS NOT YET IMPLEMENTED + :param service: The website/service that this password is associated with. + :param username: The account that this password is associated with. + :param value: The password that should be associated with the given service and username. + """ + raise NotImplementedError() diff --git a/msal_extensions/token_cache.py b/msal_extensions/token_cache.py new file mode 100644 index 0000000..878ebca --- /dev/null +++ b/msal_extensions/token_cache.py @@ -0,0 +1,170 @@ +"""Generic functions and types for working with a TokenCache that is not platform specific.""" +import os +import sys +import warnings +import time +import errno +import msal +from .cache_lock import CrossPlatLock + +if sys.platform.startswith('win'): + from .windows import WindowsDataProtectionAgent +elif sys.platform.startswith('darwin'): + from .osx import Keychain + +def _mkdir_p(path): + """Creates a directory, and any necessary parents. + + This implementation based on a Stack Overflow question that can be found here: + https://stackoverflow.com/questions/600268/mkdir-p-functionality-in-python + + If the path provided is an existing file, this function raises an exception. + :param path: The directory name that should be created. + """ + try: + os.makedirs(path) + except OSError as exp: + if exp.errno == errno.EEXIST and os.path.isdir(path): + pass + else: + raise + + +class FileTokenCache(msal.SerializableTokenCache): + """Implements basic unprotected SerializableTokenCache to a plain-text file.""" + def __init__(self, + cache_location=os.path.join( + os.getenv('LOCALAPPDATA', os.path.expanduser('~')), + '.IdentityService', + 'msal.cache'), + lock_location=None): + super(FileTokenCache, self).__init__() + self._cache_location = cache_location + self._lock_location = lock_location or self._cache_location + '.lockfile' + self._last_sync = 0 # _last_sync is a Unixtime + + self._cache_location = os.path.expanduser(self._cache_location) + self._lock_location = os.path.expanduser(self._lock_location) + + _mkdir_p(os.path.dirname(self._lock_location)) + _mkdir_p(os.path.dirname(self._cache_location)) + + def _needs_refresh(self): + # type: () -> Bool + """ + Inspects the file holding the encrypted TokenCache to see if a read is necessary. + :return: True if there are changes not reflected in memory, False otherwise. + """ + try: + updated = os.path.getmtime(self._cache_location) + return self._last_sync < updated + except IOError as exp: + if exp.errno != errno.ENOENT: + raise exp + return False + + def _write(self, contents): + # type: (str) -> None + """Handles actually committing the serialized form of this TokenCache to persisted storage. + For types derived of this, class that will be a file, which has the ability to track a last + modified time. + + :param contents: The serialized contents of a TokenCache + """ + with open(self._cache_location, 'w+') as handle: + handle.write(contents) + + def _read(self): + # type: () -> str + """Fetches the contents of a file and invokes deserialization.""" + with open(self._cache_location, 'r') as handle: + return handle.read() + + def add(self, event, **kwargs): + with CrossPlatLock(self._lock_location): + if self._needs_refresh(): + try: + self.deserialize(self._read()) + except IOError as exp: + if exp.errno != errno.ENOENT: + raise + super(FileTokenCache, self).add(event, **kwargs) # pylint: disable=duplicate-code + self._write(self.serialize()) + self._last_sync = os.path.getmtime(self._cache_location) + + def modify(self, credential_type, old_entry, new_key_value_pairs=None): + with CrossPlatLock(self._lock_location): + if self._needs_refresh(): + try: + self.deserialize(self._read()) + except IOError as exp: + if exp.errno != errno.ENOENT: + raise + super(FileTokenCache, self).modify( + credential_type, + old_entry, + new_key_value_pairs=new_key_value_pairs) + self._write(self.serialize()) + self._last_sync = os.path.getmtime(self._cache_location) + + def find(self, credential_type, **kwargs): # pylint: disable=arguments-differ + with CrossPlatLock(self._lock_location): + if self._needs_refresh(): + try: + self.deserialize(self._read()) + except IOError as exp: + if exp.errno != errno.ENOENT: + raise + self._last_sync = time.time() + return super(FileTokenCache, self).find(credential_type, **kwargs) + + +class UnencryptedTokenCache(FileTokenCache): + """An unprotected token cache to default to when no-platform specific option is available.""" + def __init__(self, **kwargs): + warnings.warn("You are using an unprotected token cache, " + "because an encrypted option is not available for {}".format(sys.platform), + RuntimeWarning) + super(UnencryptedTokenCache, self).__init__(**kwargs) + + +class WindowsTokenCache(FileTokenCache): + """A SerializableTokenCache implementation which uses Win32 encryption APIs to protect your + tokens. + """ + def __init__(self, entropy='', **kwargs): + super(WindowsTokenCache, self).__init__(**kwargs) + self._dp_agent = WindowsDataProtectionAgent(entropy=entropy) + + def _write(self, contents): + with open(self._cache_location, 'wb') as handle: + handle.write(self._dp_agent.protect(contents)) + + def _read(self): + with open(self._cache_location, 'rb') as handle: + cipher_text = handle.read() + return self._dp_agent.unprotect(cipher_text) + + +class OSXTokenCache(FileTokenCache): + """A SerializableTokenCache implementation which uses native Keychain libraries to protect your + tokens. + """ + + def __init__(self, + service_name='Microsoft.Developer.IdentityService', + account_name='MSALCache', + **kwargs): + super(OSXTokenCache, self).__init__(**kwargs) + self._service_name = service_name + self._account_name = account_name + + def _read(self): + with Keychain() as locker: + return locker.get_generic_password(self._service_name, self._account_name) + + def _write(self, contents): + with Keychain() as locker: + locker.set_generic_password(self._service_name, self._account_name, contents) + with open(self._cache_location, "w+") as handle: + handle.write('{} {}'.format(os.getpid(), sys.argv[0])) diff --git a/msal_extensions/windows.py b/msal_extensions/windows.py index 0e529f1..479c496 100644 --- a/msal_extensions/windows.py +++ b/msal_extensions/windows.py @@ -1,11 +1,6 @@ """Implements a Windows Specific TokenCache, and provides auxiliary helper types.""" -import os import ctypes from ctypes import wintypes -import time -import errno -import msal -from .cache_lock import CrossPlatLock _LOCAL_FREE = ctypes.windll.kernel32.LocalFree _GET_LAST_ERROR = ctypes.windll.kernel32.GetLastError @@ -114,81 +109,3 @@ def unprotect(self, cipher_text): _LOCAL_FREE(result.pbData) err_code = _GET_LAST_ERROR() raise OSError(256, '', '', err_code) - - -class WindowsTokenCache(msal.SerializableTokenCache): - """A SerializableTokenCache implementation which uses Win32 encryption APIs to protect your - tokens. - """ - def __init__(self, - cache_location=os.path.join( - os.getenv('LOCALAPPDATA', os.path.expanduser('~')), - '.IdentityService', - 'msal.cache'), - entropy=''): - super(WindowsTokenCache, self).__init__() - - self._cache_location = cache_location - self._lock_location = self._cache_location + '.lockfile' - self._dp_agent = WindowsDataProtectionAgent(entropy=entropy) - self._last_sync = 0 # _last_sync is a Unixtime - - def _needs_refresh(self): - # type: () -> Bool - """ - Inspects the file holding the encrypted TokenCache to see if a read is necessary. - :return: True if there are changes not reflected in memory, False otherwise. - """ - try: - return self._last_sync < os.path.getmtime(self._cache_location) - except IOError as exp: - if exp.errno != errno.ENOENT: - raise exp - return False - - def add(self, event, **kwargs): - with CrossPlatLock(self._lock_location): - if self._needs_refresh(): - try: - self._read() - except IOError as exp: - if exp.errno != errno.ENOENT: - raise exp - super(WindowsTokenCache, self).add(event, **kwargs) - self._write() - - def modify(self, credential_type, old_entry, new_key_value_pairs=None): - with CrossPlatLock(self._lock_location): - if self._needs_refresh(): - try: - self._read() - except IOError as exp: - if exp.errno != errno.ENOENT: - raise exp - super(WindowsTokenCache, self).modify( - credential_type, - old_entry, - new_key_value_pairs=new_key_value_pairs) - self._write() - - def find(self, credential_type, **kwargs): # pylint: disable=arguments-differ - with CrossPlatLock(self._lock_location): - if self._needs_refresh(): - try: - self._read() - except IOError as exp: - if exp.errno != errno.ENOENT: - raise exp - return super(WindowsTokenCache, self).find(credential_type, **kwargs) - - def _write(self): - with open(self._cache_location, 'wb') as handle: - handle.write(self._dp_agent.protect(self.serialize())) - self._last_sync = int(time.time()) - - def _read(self): - with open(self._cache_location, 'rb') as handle: - cipher_text = handle.read() - contents = self._dp_agent.unprotect(cipher_text) - self.deserialize(contents) - self._last_sync = int(time.time()) diff --git a/tests/test_agnostic_backend.py b/tests/test_agnostic_backend.py new file mode 100644 index 0000000..1d9c7d0 --- /dev/null +++ b/tests/test_agnostic_backend.py @@ -0,0 +1,54 @@ +import os +import shutil +import tempfile +import pytest +import msal + + +def test_file_token_cache_roundtrip(): + from msal_extensions.token_cache import FileTokenCache + + client_id = os.getenv('AZURE_CLIENT_ID') + client_secret = os.getenv('AZURE_CLIENT_SECRET') + if not (client_id and client_secret): + pytest.skip('no credentials present to test FileTokenCache round-trip with.') + + test_folder = tempfile.mkdtemp(prefix="msal_extension_test_file_token_cache_roundtrip") + cache_file = os.path.join(test_folder, 'msal.cache') + try: + subject = FileTokenCache(cache_location=cache_file) + app = msal.ConfidentialClientApplication( + client_id=client_id, + client_credential=client_secret, + token_cache=subject) + desired_scopes = ['https://graph.microsoft.com/.default'] + token1 = app.acquire_token_for_client(scopes=desired_scopes) + os.utime(cache_file, None) # Mock having another process update the cache. + token2 = app.acquire_token_silent(scopes=desired_scopes, account=None) + assert token1['access_token'] == token2['access_token'] + finally: + shutil.rmtree(test_folder, ignore_errors=True) + + +def test_current_platform_cache_roundtrip(): + from msal_extensions import TokenCache + client_id = os.getenv('AZURE_CLIENT_ID') + client_secret = os.getenv('AZURE_CLIENT_SECRET') + if not (client_id and client_secret): + pytest.skip('no credentials present to test FileTokenCache round-trip with.') + + test_folder = tempfile.mkdtemp(prefix="msal_extension_test_file_token_cache_roundtrip") + cache_file = os.path.join(test_folder, 'msal.cache') + try: + subject = TokenCache(cache_location=cache_file) + app = msal.ConfidentialClientApplication( + client_id=client_id, + client_credential=client_secret, + token_cache=subject) + desired_scopes = ['https://graph.microsoft.com/.default'] + token1 = app.acquire_token_for_client(scopes=desired_scopes) + os.utime(cache_file, None) # Mock having another process update the cache. + token2 = app.acquire_token_silent(scopes=desired_scopes, account=None) + assert token1['access_token'] == token2['access_token'] + finally: + shutil.rmtree(test_folder, ignore_errors=True) diff --git a/tests/test_macos_backend.py b/tests/test_macos_backend.py new file mode 100644 index 0000000..c0ca8e1 --- /dev/null +++ b/tests/test_macos_backend.py @@ -0,0 +1,45 @@ +import sys +import os +import shutil +import tempfile +import pytest +import uuid +import msal + +if not sys.platform.startswith('darwin'): + pytest.skip('skipping MacOS-only tests', allow_module_level=True) +else: + from msal_extensions.osx import Keychain + from msal_extensions.token_cache import OSXTokenCache + + +def test_keychain_roundtrip(): + with Keychain() as subject: + location, account = "msal_extension_test1", "test_account1" + want = uuid.uuid4().hex + subject.set_generic_password(location, account, want) + got = subject.get_generic_password(location, account) + assert got == want + + +def test_osx_token_cache_roundtrip(): + client_id = os.getenv('AZURE_CLIENT_ID') + client_secret = os.getenv('AZURE_CLIENT_SECRET') + if not (client_id and client_secret): + pytest.skip('no credentials present to test OSXTokenCache round-trip with.') + + test_folder = tempfile.mkdtemp(prefix="msal_extension_test_osx_token_cache_roundtrip") + cache_file = os.path.join(test_folder, 'msal.cache') + try: + subject = OSXTokenCache(cache_location=cache_file) + app = msal.ConfidentialClientApplication( + client_id=client_id, + client_credential=client_secret, + token_cache=subject) + desired_scopes = ['https://graph.microsoft.com/.default'] + token1 = app.acquire_token_for_client(scopes=desired_scopes) + os.utime(cache_file, None) # Mock having another process update the cache. + token2 = app.acquire_token_silent(scopes=desired_scopes, account=None) + assert token1['access_token'] == token2['access_token'] + finally: + shutil.rmtree(test_folder, ignore_errors=True) diff --git a/tests/test_windows_backend.py b/tests/test_windows_backend.py index 6ff822e..240b93d 100644 --- a/tests/test_windows_backend.py +++ b/tests/test_windows_backend.py @@ -10,7 +10,8 @@ if not sys.platform.startswith('win'): pytest.skip('skipping windows-only tests', allow_module_level=True) else: - from msal_extensions.windows import WindowsDataProtectionAgent, WindowsTokenCache + from msal_extensions.windows import WindowsDataProtectionAgent + from msal_extensions.token_cache import WindowsTokenCache def test_dpapi_roundtrip_with_entropy():