From 88e9b22fc581a89b8a80241e57578ab4cb9b8370 Mon Sep 17 00:00:00 2001 From: Abhidnya Date: Thu, 22 Oct 2020 09:02:16 -0700 Subject: [PATCH 1/2] Initial commit Code refactoring Changing reference doc string Adding tests PR review 1 Adding dependencies and polishing code Python 2 compat --- msal/application.py | 21 ++++++++++++++++++++- setup.py | 2 ++ tests/test_application.py | 27 +++++++++++++++++++++++++++ 3 files changed, 49 insertions(+), 1 deletion(-) diff --git a/msal/application.py b/msal/application.py index 1769352b..15ef548d 100644 --- a/msal/application.py +++ b/msal/application.py @@ -1,6 +1,11 @@ import functools import json import time + +import six +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization + try: # Python 2 from urlparse import urljoin except: # Python 3 @@ -124,6 +129,7 @@ def __init__( "private_key": "...-----BEGIN PRIVATE KEY-----...", "thumbprint": "A1B2C3D4E5F6...", "public_certificate": "...-----BEGIN CERTIFICATE-----..." (Optional. See below.) + "passphrase": "Passphrase if the private_key is encrypted (Optional)" } *Added in version 0.5.0*: @@ -252,8 +258,21 @@ def _build_client(self, client_credential, authority): headers = {} if 'public_certificate' in client_credential: headers["x5c"] = extract_certs(client_credential['public_certificate']) + if not client_credential.get("passphrase"): + unencrypted_private_key = client_credential['private_key'] + else: + if isinstance(client_credential['private_key'], six.text_type): + private_key = client_credential['private_key'].encode(encoding="utf-8") + else: + private_key = client_credential['private_key'] + if isinstance(client_credential['passphrase'], six.text_type): + password = client_credential['passphrase'].encode(encoding="utf-8") + else: + password = client_credential['passphrase'] + unencrypted_private_key = serialization.load_pem_private_key( + private_key, password=password, backend=default_backend()) assertion = JwtAssertionCreator( - client_credential["private_key"], algorithm="RS256", + unencrypted_private_key, algorithm="RS256", sha1_thumbprint=client_credential.get("thumbprint"), headers=headers) client_assertion = assertion.create_regenerative_assertion( audience=authority.token_endpoint, issuer=self.client_id, diff --git a/setup.py b/setup.py index 960d4bca..4ca79d33 100644 --- a/setup.py +++ b/setup.py @@ -74,6 +74,8 @@ install_requires=[ 'requests>=2.0.0,<3', 'PyJWT[crypto]>=1.0.0,<2', + 'six>=1.6', + 'cryptography>=2.1.4' ] ) diff --git a/tests/test_application.py b/tests/test_application.py index 3281dc04..5751114a 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -39,6 +39,33 @@ def test_extract_multiple_tag_enclosed_certs(self): self.assertEqual(["my_cert1", "my_cert2"], extract_certs(pem)) +class TestEncryptedKeyAsClientCredential(unittest.TestCase): + # Internally, we use serialization.load_pem_private_key() to load an encrypted private key with a passphrase + # This function takes in encrypted key in bytes and passphrase in bytes too + # Our code handles such a conversion, adding test cases to verify such a conversion is needed + + def test_encyrpted_key_in_bytes_and_string_password_should_error(self): + private_key = b""" + -----BEGIN ENCRYPTED PRIVATE KEY----- + test_private_key + -----END ENCRYPTED PRIVATE KEY----- + """ + with self.assertRaises(TypeError): + # Using a unicode string for Python 2 to identify it as a string and not default to bytes + serialization.load_pem_private_key( + private_key, password=u"string_password", backend=default_backend()) + + def test_encyrpted_key_is_string_and_bytes_password_should_error(self): + private_key = u""" + -----BEGIN ENCRYPTED PRIVATE KEY----- + test_private_key + -----END ENCRYPTED PRIVATE KEY----- + """ + with self.assertRaises(TypeError): + serialization.load_pem_private_key( + private_key, password=b"byte_password", backend=default_backend()) + + class TestClientApplicationAcquireTokenSilentErrorBehaviors(unittest.TestCase): def setUp(self): From 2ded2775cb8330848c5f6e4e8ddd474a656e217e Mon Sep 17 00:00:00 2001 From: Ray Luo Date: Thu, 29 Oct 2020 00:18:09 -0700 Subject: [PATCH 2/2] Removing dependency of six Adding missing arguments to api call Use cryptography lower bound as low as 0.6 Add test cases for _str2bytes() Choose cryptography upper bound as <4 --- msal/application.py | 28 ++++++++++++++-------------- setup.py | 12 ++++++++++-- tests/test_application.py | 30 ++++++------------------------ 3 files changed, 30 insertions(+), 40 deletions(-) diff --git a/msal/application.py b/msal/application.py index 15ef548d..cae9013d 100644 --- a/msal/application.py +++ b/msal/application.py @@ -1,11 +1,6 @@ import functools import json import time - -import six -from cryptography.hazmat.backends import default_backend -from cryptography.hazmat.primitives import serialization - try: # Python 2 from urlparse import urljoin except: # Python 3 @@ -95,6 +90,14 @@ def _merge_claims_challenge_and_capabilities(capabilities, claims_challenge): return json.dumps(claims_dict) +def _str2bytes(raw): + # A conversion based on duck-typing rather than six.text_type + try: + return raw.encode(encoding="utf-8") + except: + return raw + + class ClientApplication(object): ACQUIRE_TOKEN_SILENT_ID = "84" @@ -261,16 +264,13 @@ def _build_client(self, client_credential, authority): if not client_credential.get("passphrase"): unencrypted_private_key = client_credential['private_key'] else: - if isinstance(client_credential['private_key'], six.text_type): - private_key = client_credential['private_key'].encode(encoding="utf-8") - else: - private_key = client_credential['private_key'] - if isinstance(client_credential['passphrase'], six.text_type): - password = client_credential['passphrase'].encode(encoding="utf-8") - else: - password = client_credential['passphrase'] + from cryptography.hazmat.primitives import serialization + from cryptography.hazmat.backends import default_backend unencrypted_private_key = serialization.load_pem_private_key( - private_key, password=password, backend=default_backend()) + _str2bytes(client_credential["private_key"]), + _str2bytes(client_credential["passphrase"]), + backend=default_backend(), # It was a required param until 2020 + ) assertion = JwtAssertionCreator( unencrypted_private_key, algorithm="RS256", sha1_thumbprint=client_credential.get("thumbprint"), headers=headers) diff --git a/setup.py b/setup.py index 4ca79d33..51c988dd 100644 --- a/setup.py +++ b/setup.py @@ -74,8 +74,16 @@ install_requires=[ 'requests>=2.0.0,<3', 'PyJWT[crypto]>=1.0.0,<2', - 'six>=1.6', - 'cryptography>=2.1.4' + + 'cryptography>=0.6,<4', + # load_pem_private_key() is available since 0.6 + # https://github.com/pyca/cryptography/blob/master/CHANGELOG.rst#06---2014-09-29 + # + # Not sure what should be used as an upper bound here + # https://github.com/pyca/cryptography/issues/5532 + # We will go with "<4" for now, which is also what our another dependency, + # pyjwt, currently use. + ] ) diff --git a/tests/test_application.py b/tests/test_application.py index 5751114a..8d48a0ac 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -1,6 +1,7 @@ # Note: Since Aug 2019 we move all e2e tests into test_e2e.py, # so this test_application file contains only unit tests without dependency. from msal.application import * +from msal.application import _str2bytes import msal from msal.application import _merge_claims_challenge_and_capabilities from tests import unittest @@ -39,31 +40,12 @@ def test_extract_multiple_tag_enclosed_certs(self): self.assertEqual(["my_cert1", "my_cert2"], extract_certs(pem)) -class TestEncryptedKeyAsClientCredential(unittest.TestCase): - # Internally, we use serialization.load_pem_private_key() to load an encrypted private key with a passphrase - # This function takes in encrypted key in bytes and passphrase in bytes too - # Our code handles such a conversion, adding test cases to verify such a conversion is needed +class TestBytesConversion(unittest.TestCase): + def test_string_to_bytes(self): + self.assertEqual(type(_str2bytes("some string")), type(b"bytes")) - def test_encyrpted_key_in_bytes_and_string_password_should_error(self): - private_key = b""" - -----BEGIN ENCRYPTED PRIVATE KEY----- - test_private_key - -----END ENCRYPTED PRIVATE KEY----- - """ - with self.assertRaises(TypeError): - # Using a unicode string for Python 2 to identify it as a string and not default to bytes - serialization.load_pem_private_key( - private_key, password=u"string_password", backend=default_backend()) - - def test_encyrpted_key_is_string_and_bytes_password_should_error(self): - private_key = u""" - -----BEGIN ENCRYPTED PRIVATE KEY----- - test_private_key - -----END ENCRYPTED PRIVATE KEY----- - """ - with self.assertRaises(TypeError): - serialization.load_pem_private_key( - private_key, password=b"byte_password", backend=default_backend()) + def test_bytes_to_bytes(self): + self.assertEqual(type(_str2bytes(b"some bytes")), type(b"bytes")) class TestClientApplicationAcquireTokenSilentErrorBehaviors(unittest.TestCase):