Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 19 additions & 22 deletions kms/api-client/asymmetric.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.rom googleapiclient import discovery

# [START kms_asymmetric_imports]
import base64
import hashlib

from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import ec, padding, utils
# [END kms_asymmetric_imports]


# [START kms_get_asymmetric_public]
Expand All @@ -43,35 +45,34 @@ def getAsymmetricPublicKey(client, key_path):
# [START kms_decrypt_rsa]
def decryptRSA(ciphertext, client, key_path):
"""
Decrypt a given ciphertext using an 'RSA_DECRYPT_OAEP_2048_SHA256' private
key stored on Cloud KMS
Decrypt the input ciphertext (bytes) using an
'RSA_DECRYPT_OAEP_2048_SHA256' private key stored on Cloud KMS
"""
request_body = {'ciphertext': base64.b64encode(ciphertext).decode('utf-8')}
request = client.projects() \
.locations() \
.keyRings() \
.cryptoKeys() \
.cryptoKeyVersions() \
.asymmetricDecrypt(name=key_path,
body={'ciphertext': ciphertext})
body=request_body)
response = request.execute()
plaintext = base64.b64decode(response['plaintext']).decode('utf-8')
plaintext = base64.b64decode(response['plaintext'])
return plaintext
# [END kms_decrypt_rsa]


# [START kms_encrypt_rsa]
def encryptRSA(message, client, key_path):
def encryptRSA(plaintext, client, key_path):
"""
Encrypt message locally using an 'RSA_DECRYPT_OAEP_2048_SHA256' public
key retrieved from Cloud KMS
Encrypt the input plaintext (bytes) locally using an
'RSA_DECRYPT_OAEP_2048_SHA256' public key retrieved from Cloud KMS
"""
public_key = getAsymmetricPublicKey(client, key_path)
pad = padding.OAEP(mgf=padding.MGF1(algorithm=hashes.SHA256()),
algorithm=hashes.SHA256(),
label=None)
ciphertext = public_key.encrypt(message.encode('ascii'), pad)
ciphertext = base64.b64encode(ciphertext).decode('utf-8')
return ciphertext
return public_key.encrypt(plaintext, pad)
# [END kms_encrypt_rsa]


Expand All @@ -82,7 +83,7 @@ def signAsymmetric(message, client, key_path):
"""
# Note: some key algorithms will require a different hash function
# For example, EC_SIGN_P384_SHA384 requires SHA384
digest_bytes = hashlib.sha256(message.encode('ascii')).digest()
digest_bytes = hashlib.sha256(message).digest()
digest64 = base64.b64encode(digest_bytes)

digest_JSON = {'sha256': digest64.decode('utf-8')}
Expand All @@ -94,24 +95,22 @@ def signAsymmetric(message, client, key_path):
.asymmetricSign(name=key_path,
body={'digest': digest_JSON})
response = request.execute()
return response.get('signature', None)
return base64.b64decode(response.get('signature', None))
# [END kms_sign_asymmetric]


# [START kms_verify_signature_rsa]
def verifySignatureRSA(signature, message, client, key_path):
"""
Verify the validity of an 'RSA_SIGN_PSS_2048_SHA256' signature for the
specified plaintext message
specified message
"""
public_key = getAsymmetricPublicKey(client, key_path)

digest_bytes = hashlib.sha256(message.encode('ascii')).digest()
sig_bytes = base64.b64decode(signature)
digest_bytes = hashlib.sha256(message).digest()

try:
# Attempt verification
public_key.verify(sig_bytes,
public_key.verify(signature,
digest_bytes,
padding.PSS(mgf=padding.MGF1(hashes.SHA256()),
salt_length=32),
Expand All @@ -127,16 +126,14 @@ def verifySignatureRSA(signature, message, client, key_path):
def verifySignatureEC(signature, message, client, key_path):
"""
Verify the validity of an 'EC_SIGN_P256_SHA256' signature
for the specified plaintext message
for the specified message
"""
public_key = getAsymmetricPublicKey(client, key_path)

digest_bytes = hashlib.sha256(message.encode('ascii')).digest()
sig_bytes = base64.b64decode(signature)
digest_bytes = hashlib.sha256(message).digest()

try:
# Attempt verification
public_key.verify(sig_bytes,
public_key.verify(signature,
digest_bytes,
ec.ECDSA(utils.Prehashed(hashes.SHA256())))
# No errors were thrown. Verification was successful
Expand Down
40 changes: 24 additions & 16 deletions kms/api-client/asymmetric_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.


from os import environ
from time import sleep

Expand Down Expand Up @@ -89,6 +88,7 @@ class TestKMSSamples:
.format(parent, keyring, ecSignId)

message = 'test message 123'
message_bytes = message.encode('utf-8')

client = discovery.build('cloudkms', 'v1')

Expand All @@ -99,44 +99,52 @@ def test_get_public_key(self):
assert isinstance(ec_key, _EllipticCurvePublicKey), 'expected EC key'

def test_rsa_encrypt_decrypt(self):
ciphertext = sample.encryptRSA(self.message,
ciphertext = sample.encryptRSA(self.message_bytes,
self.client,
self.rsaDecrypt)
# ciphertext should be 344 characters with base64 and RSA 2048
assert len(ciphertext) == 344, \
'ciphertext should be 344 chars; got {}'.format(len(ciphertext))
assert ciphertext[-2:] == '==', 'cipher text should end with =='
plaintext = sample.decryptRSA(ciphertext, self.client, self.rsaDecrypt)
# ciphertext should be 256 characters with base64 and RSA 2048
assert len(ciphertext) == 256, \
'ciphertext should be 256 chars; got {}'.format(len(ciphertext))
plaintext_bytes = sample.decryptRSA(ciphertext,
self.client,
self.rsaDecrypt)
assert plaintext_bytes == self.message_bytes
plaintext = plaintext_bytes.decode('utf-8')
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there much value in checking that plaintext == self.message if you've already checked that plaintext_bytes == self.message_bytes?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably not, but I'd rather err on the side of too many tests, especially when dealing with all these encoding changes

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Most of the time you specify utf-8 for string.encode() and bytes.decode(), except on line 51 where you just assume the default (which is utf-8). Both ways work fine, but this is a minor inconsistency.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good point, we should probably keep that consistent. Fixed

assert plaintext == self.message

def test_rsa_sign_verify(self):
sig = sample.signAsymmetric(self.message, self.client, self.rsaSign)
sig = sample.signAsymmetric(self.message_bytes,
self.client,
self.rsaSign)
# ciphertext should be 344 characters with base64 and RSA 2048
assert len(sig) == 344, \
'sig should be 344 chars; got {}'.format(len(sig))
assert sig[-2:] == '==', 'sig should end with =='
assert len(sig) == 256, \
'sig should be 256 chars; got {}'.format(len(sig))
success = sample.verifySignatureRSA(sig,
self.message,
self.message_bytes,
self.client,
self.rsaSign)
assert success is True, 'RSA verification failed'
changed_bytes = self.message_bytes + b'.'
success = sample.verifySignatureRSA(sig,
self.message+'.',
changed_bytes,
self.client,
self.rsaSign)
assert success is False, 'verify should fail with modified message'

def test_ec_sign_verify(self):
sig = sample.signAsymmetric(self.message, self.client, self.ecSign)
sig = sample.signAsymmetric(self.message_bytes,
self.client,
self.ecSign)
assert len(sig) > 50 and len(sig) < 300, \
'sig outside expected length range'
success = sample.verifySignatureEC(sig,
self.message,
self.message_bytes,
self.client,
self.ecSign)
assert success is True, 'EC verification failed'
changed_bytes = self.message_bytes + b'.'
success = sample.verifySignatureEC(sig,
self.message+'.',
changed_bytes,
self.client,
self.ecSign)
assert success is False, 'verify should fail with modified message'