diff --git a/msal_extensions/windows.py b/msal_extensions/windows.py index ba5870d..528d51f 100644 --- a/msal_extensions/windows.py +++ b/msal_extensions/windows.py @@ -8,6 +8,7 @@ from .cache_lock import CrossPlatLock _LOCAL_FREE = ctypes.windll.kernel32.LocalFree +_GET_LAST_ERROR = ctypes.windll.kernel32.GetLastError _MEMCPY = ctypes.cdll.msvcrt.memcpy _CRYPT_PROTECT_DATA = ctypes.windll.crypt32.CryptProtectData _CRYPT_UNPROTECT_DATA = ctypes.windll.crypt32.CryptUnprotectData @@ -81,7 +82,9 @@ def protect(self, message): return result.raw() finally: _LOCAL_FREE(result.pbData) - return b'' + + err_code = _GET_LAST_ERROR() + raise OSError(256, '', '', err_code) def unprotect(self, cipher_text): # type: (bytes) -> str @@ -109,7 +112,8 @@ def unprotect(self, cipher_text): return result.raw().decode('utf-8') finally: _LOCAL_FREE(result.pbData) - return u'' + err_code = _GET_LAST_ERROR() + raise OSError(256, '', '', err_code) class WindowsTokenCache(msal.SerializableTokenCache): @@ -118,7 +122,7 @@ class WindowsTokenCache(msal.SerializableTokenCache): """ def __init__(self, cache_location=os.path.join( - os.getenv('LOCALAPPDATA'), + os.getenv('LOCALAPPDATA', os.path.expanduser('~')), '.IdentityService', 'msal.cache'), entropy=''): @@ -137,7 +141,7 @@ def _needs_refresh(self): """ try: return self._last_sync < os.path.getmtime(self._cache_location) - except OSError as exp: + except IOError as exp: if exp.errno != errno.ENOENT: raise exp return False @@ -147,7 +151,7 @@ def add(self, event, **kwargs): if self._needs_refresh(): try: self._read() - except OSError as exp: + except IOError as exp: if exp.errno != errno.ENOENT: raise exp super(WindowsTokenCache, self).add(event, **kwargs) @@ -158,7 +162,7 @@ def update_rt(self, rt_item, new_rt): if self._needs_refresh(): try: self._read() - except OSError as exp: + except IOError as exp: if exp.errno != errno.ENOENT: raise exp super(WindowsTokenCache, self).update_rt(rt_item, new_rt) @@ -169,7 +173,7 @@ def remove_rt(self, rt_item): if self._needs_refresh(): try: self._read() - except OSError as exp: + except IOError as exp: if exp.errno != errno.ENOENT: raise exp super(WindowsTokenCache, self).remove_rt(rt_item) @@ -180,7 +184,7 @@ def find(self, credential_type, **kwargs): # pylint: disable=arguments-differ if self._needs_refresh(): try: self._read() - except OSError as exp: + except IOError as exp: if exp.errno != errno.ENOENT: raise exp return super(WindowsTokenCache, self).find(credential_type, **kwargs) diff --git a/tests/test_windows_backend.py b/tests/test_windows_backend.py index 2999c97..6ff822e 100644 --- a/tests/test_windows_backend.py +++ b/tests/test_windows_backend.py @@ -1,5 +1,6 @@ import sys import os +import errno import shutil import tempfile import pytest @@ -24,18 +25,24 @@ def test_dpapi_roundtrip_with_entropy(): uuid.uuid4().hex, ] - for tc in test_cases: - ciphered = subject_with_entropy.protect(tc) - assert ciphered != tc + try: + for tc in test_cases: + ciphered = subject_with_entropy.protect(tc) + assert ciphered != tc - got = subject_with_entropy.unprotect(ciphered) - assert got == tc + got = subject_with_entropy.unprotect(ciphered) + assert got == tc - ciphered = subject_without_entropy.protect(tc) - assert ciphered != tc + ciphered = subject_without_entropy.protect(tc) + assert ciphered != tc - got = subject_without_entropy.unprotect(ciphered) - assert got == tc + got = subject_without_entropy.unprotect(ciphered) + assert got == tc + except OSError as exp: + if exp.errno == errno.EIO and os.getenv('TRAVIS_REPO_SLUG'): + pytest.skip('DPAPI tests are known to fail in TravisCI. This effort tracked by ' + 'https://github.com/AzureAD/microsoft-authentication-extentions-for-python' + '/issues/21') def test_read_msal_cache_direct(): @@ -43,10 +50,11 @@ def test_read_msal_cache_direct(): This loads and unprotects an MSAL cache directly, only using the DataProtectionAgent. It is not meant to test the wrapper `WindowsTokenCache`. """ + localappdata_location = os.getenv('LOCALAPPDATA', os.path.expanduser('~')) cache_locations = [ - os.path.join(os.getenv('LOCALAPPDATA'), '.IdentityService', 'msal.cache'), # this is where it's supposed to be - os.path.join(os.getenv('LOCALAPPDATA'), '.IdentityServices', 'msal.cache'), # There was a miscommunications about whether this was plural or not. - os.path.join(os.getenv('LOCALAPPDATA'), 'msal.cache'), # The earliest most naive builds used this locations. + os.path.join(localappdata_location, '.IdentityService', 'msal.cache'), # this is where it's supposed to be + os.path.join(localappdata_location, '.IdentityServices', 'msal.cache'), # There was a miscommunications about whether this was plural or not. + os.path.join(localappdata_location, 'msal.cache'), # The earliest most naive builds used this locations. ] found = False @@ -55,9 +63,12 @@ def test_read_msal_cache_direct(): with open(loc, mode='rb') as fh: contents = fh.read() found = True + break - except FileNotFoundError: - pass + except IOError as exp: + if exp.errno != errno.ENOENT: + raise exp + if not found: pytest.skip('could not find the msal.cache file (try logging in using MSAL)')