Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions tests/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
from __future__ import absolute_import, division, generators, unicode_literals, print_function, nested_scopes, \
with_statement

from tests.assets import TEST_KEY_RSA_4096, TEST_KEY_ECDSA, TEST_KEY_RSA_2048_ENCRYPTED, POLICY_CLOUD1, POLICY_TPP1, \
EXAMPLE_CSR, EXAMPLE_CHAIN
from vcert import CloudConnection, CertificateRequest, TPPConnection, FakeConnection, ZoneConfig, RevocationRequest, \
TPPTokenConnection
from vcert.common import CertField, KeyType
Expand All @@ -40,7 +38,8 @@
from cryptography.hazmat.primitives import serialization, hashes
logging.basicConfig(level=logging.DEBUG)
logging.getLogger("urllib3").setLevel(logging.DEBUG)

from assets import TEST_KEY_RSA_4096, TEST_KEY_ECDSA, TEST_KEY_RSA_2048_ENCRYPTED, POLICY_CLOUD1, POLICY_TPP1,\
EXAMPLE_CSR, EXAMPLE_CHAIN

FAKE = environ.get('FAKE')
TOKEN = environ.get('CLOUD_APIKEY')
Expand Down
263 changes: 246 additions & 17 deletions vcert/connection_tpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,46 +17,66 @@
from __future__ import (absolute_import, division, generators, unicode_literals, print_function, nested_scopes,
with_statement)

import base64
import logging as log
import re
import time

import requests
from cryptography.hazmat.backends import default_backend
from cryptography import x509
from cryptography.x509 import SignatureAlgorithmOID as algos

from .common import MIME_JSON
from .connection_tpp_common import TPPCommonConnection, URLS
from .errors import (ClientBadData, AuthenticationError, ServerUnexptedBehavior)
from .common import CommonConnection, MIME_JSON, CertField, ZoneConfig, Policy, KeyType
from .pem import parse_pem
from .errors import (ServerUnexptedBehavior, ClientBadData, CertificateRequestError, AuthenticationError,
CertificateRenewError)
from .http import HTTPStatus


class URLS:
API_BASE_URL = ""

AUTHORIZE = "authorize/"
CERTIFICATE_REQUESTS = "certificates/request"
CERTIFICATE_RETRIEVE = "certificates/retrieve"
FIND_POLICY = "config/findpolicy"
CERTIFICATE_REVOKE = "certificates/revoke"
CERTIFICATE_RENEW = "certificates/renew"
CERTIFICATE_SEARCH = "certificates/"
CERTIFICATE_IMPORT = "certificates/import"
ZONE_CONFIG = "certificates/checkpolicy"
CONFIG_READ_DN = "Config/ReadDn"


TOKEN_HEADER_NAME = "x-venafi-api-key" # nosec


class TPPConnection(TPPCommonConnection):
class TPPConnection(CommonConnection):
def __init__(self, user, password, url, http_request_kwargs=None):
"""
:param str user:
:param str password:
:param str url:
:param dict[str,Any] http_request_kwargs:
"""
super().__init__(http_request_kwargs=http_request_kwargs)
self._base_url = url # type: str
self._user = user # type: str
self._password = password # type: str
self._token = None # type: tuple
if http_request_kwargs is None:
http_request_kwargs = {"timeout": 180}
elif "timeout" not in http_request_kwargs:
http_request_kwargs["timeout"] = 180
self._http_request_kwargs = http_request_kwargs or {}

def _create_url_dictionary(self):
self.urls = dict()
self.urls[URLS.CERTIFICATE_REQUESTS] = URLS.CERTIFICATE_REQUESTS
self.urls[URLS.CERTIFICATE_RETRIEVE] = URLS.CERTIFICATE_RETRIEVE
self.urls[URLS.FIND_POLICY] = URLS.FIND_POLICY
self.urls[URLS.CERTIFICATE_REVOKE] = URLS.CERTIFICATE_REVOKE
self.urls[URLS.CERTIFICATE_RENEW] = URLS.CERTIFICATE_RENEW
self.urls[URLS.CERTIFICATE_SEARCH] = URLS.CERTIFICATE_SEARCH
self.urls[URLS.CERTIFICATE_IMPORT] = URLS.CERTIFICATE_IMPORT
self.urls[URLS.ZONE_CONFIG] = URLS.ZONE_CONFIG
self.urls[URLS.CONFIG_READ_DN] = URLS.CONFIG_READ_DN
def __setattr__(self, key, value):
if key == "_base_url":
value = self._normalize_and_verify_base_url(value)
self.__dict__[key] = value
Comment on lines +73 to +76
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a handy Python wrapper called @Property. This can be handy here. It would look like this:
@Property
def base_url(self):
# This is a getter
return self._base_url

@base_url.setter
def base_url(self, value):
# This is the setter method
self._base_url = self._normalize_and_verify_base_url(value)

It's nicer for refactoring and is pretty explicit.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for this suggestion @HELGAHR , we will look into it.


def __str__(self):
return "[TPP] %s" % self._base_url

def _get(self, url="", params=None):
if not self._token or self._token[1] < time.time() + 1:
Expand Down Expand Up @@ -109,11 +129,220 @@ def auth(self):
log.error("Authentication status is not %s but %s. Exiting" % (HTTPStatus.OK, status[0]))
raise AuthenticationError

# TODO: Need to add service generated CSR implementation
def request_cert(self, request, zone):
if not request.csr:
request.build_csr()
request_data = {"PolicyDN": self._get_policy_dn(zone),
"PKCS10": request.csr,
"ObjectName": request.friendly_name,
"DisableAutomaticRenewal": "true"}
if request.origin:
request_data["Origin"] = request.origin
ca_origin = {"Name": "Origin", "Value": request.origin}
if request_data.get("CASpecificAttributes"):
request_data["CASpecificAttributes"].append(ca_origin)
else:
request_data["CASpecificAttributes"] = [ca_origin]
status, data = self._post(URLS.CERTIFICATE_REQUESTS, data=request_data)
if status == HTTPStatus.OK:
request.id = data['CertificateDN']
log.debug("Certificate sucessfully requested with request id %s." % request.id)
return True

log.error("Request status is not %s. %s." % HTTPStatus.OK, status)
raise CertificateRequestError

def retrieve_cert(self, certificate_request):
log.debug("Getting certificate status for id %s" % certificate_request.id)

retrive_request = dict(CertificateDN=certificate_request.id, Format="base64", IncludeChain='true')

if certificate_request.chain_option == "last":
retrive_request['RootFirstOrder'] = 'false'
retrive_request['IncludeChain'] = 'true'
elif certificate_request.chain_option == "first":
retrive_request['RootFirstOrder'] = 'true'
retrive_request['IncludeChain'] = 'true'
elif certificate_request.chain_option == "ignore":
retrive_request['IncludeChain'] = 'false'
else:
log.error("chain option %s is not valid" % certificate_request.chain_option)
raise ClientBadData

status, data = self._post(URLS.CERTIFICATE_RETRIEVE, data=retrive_request)
if status == HTTPStatus.OK:
pem64 = data['CertificateData']
pem = base64.b64decode(pem64)
return parse_pem(pem.decode(), certificate_request.chain_option)
elif status == HTTPStatus.ACCEPTED:
log.debug(data['Status'])
return None

log.error("Status is not %s. %s" % HTTPStatus.OK, status)
raise ServerUnexptedBehavior

def revoke_cert(self, request):
if not (request.id or request.thumbprint):
raise ClientBadData
d = {
"Disable": request.disable
}
if request.reason:
d["Reason"] = request.reason
if request.id:
d["CertificateDN"] = request.id
elif request.thumbprint:
d["Thumbprint"] = request.thumbprint
else:
raise ClientBadData
if request.comments:
d["Comments"] = request.comments
status, data = self._post(URLS.CERTIFICATE_REVOKE, data=d)
if status in (HTTPStatus.OK, HTTPStatus.ACCEPTED):
return data

raise ServerUnexptedBehavior

def renew_cert(self, request, reuse_key=False):
if not request.id and not request.thumbprint:
log.debug("Request id or thumbprint must be specified for TPP")
raise CertificateRenewError
if not request.id and request.thumbprint:
request.id = self.search_by_thumbprint(request.thumbprint)
if reuse_key:
log.debug("Trying to renew certificate %s" % request.id)
status, data = self._post(URLS.CERTIFICATE_RENEW, data={"CertificateDN": request.id})
if not data['Success']:
raise CertificateRenewError
return
cert = self.retrieve_cert(request)
cert = x509.load_pem_x509_certificate(cert.cert.encode(), default_backend())
for a in cert.subject:
if a.oid == x509.NameOID.COMMON_NAME:
request.common_name = a.value
elif a.oid == x509.NameOID.COUNTRY_NAME:
request.country = a.value
elif a.oid == x509.NameOID.LOCALITY_NAME:
request.locality = a.value
elif a.oid == x509.NameOID.STATE_OR_PROVINCE_NAME:
request.province = a.value
elif a.oid == x509.NameOID.ORGANIZATION_NAME:
request.organization = a.value
elif a.oid == x509.NameOID.ORGANIZATIONAL_UNIT_NAME:
request.organizational_unit = a.value
for e in cert.extensions:
if e.oid == x509.OID_SUBJECT_ALTERNATIVE_NAME:
request.san_dns = list([x.value for x in e.value if isinstance(x, x509.DNSName)])
request.email_addresses = list([x.value for x in e.value if isinstance(x, x509.RFC822Name)])
request.ip_addresses = list([x.value.exploded for x in e.value if isinstance(x, x509.IPAddress)])
if cert.signature_algorithm_oid in (algos.ECDSA_WITH_SHA1, algos.ECDSA_WITH_SHA224, algos.ECDSA_WITH_SHA256,
algos.ECDSA_WITH_SHA384, algos.ECDSA_WITH_SHA512):
request.key_type = (KeyType.ECDSA, KeyType.ALLOWED_CURVES[0])
else:
request.key_type = KeyType(KeyType.RSA, 2048) # todo: make parsing key size
if not request.csr:
request.build_csr()
status, data = self._post(URLS.CERTIFICATE_RENEW,
data={"CertificateDN": request.id, "PKCS10": request.csr})
if status == HTTPStatus.OK:
if "CertificateDN" in data:
request.id = data['CertificateDN']
log.debug("Certificate successfully requested with request id %s." % request.id)
return True

log.error("Request status is not %s. %s." % HTTPStatus.OK, status)
raise CertificateRequestError

@staticmethod
def _parse_zone_config_to_policy(data):
# todo: parse over values to regexps (dont forget tests!)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

address todo now ?? as these are lot of if, else in here

p = data["Policy"]
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How safe is it in this method to assume that these dictionary keys resolve? I'm new to this code, but I usually think thrice before trying to access a node in the dictionary without .get().

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for this suggestion as well. Although the code itself was not added in this PR (it was present before), it is nice to have potentially dangerous code design choices pinpointed

if p["KeyPair"]["KeyAlgorithm"]["Locked"]:
if p["KeyPair"]["KeyAlgorithm"]["Value"] == "RSA":
if p["KeyPair"]["KeySize"]["Locked"]:
key_types = [KeyType(KeyType.RSA, p["KeyPair"]["KeySize"]["Value"])]
else:
key_types = [KeyType(KeyType.RSA, x) for x in KeyType.ALLOWED_SIZES]
elif p["KeyPair"]["KeyAlgorithm"]["Value"] == "ECC":
if p["KeyPair"]["EllipticCurve"]["Locked"]:
key_types = [KeyType(KeyType.ECDSA, p["KeyPair"]["EllipticCurve"]["Value"])]
else:
key_types = [KeyType(KeyType.ECDSA, x) for x in KeyType.ALLOWED_CURVES]
else:
raise ServerUnexptedBehavior
else:
key_types = []
if p["KeyPair"].get("KeySize", {}).get("Locked"):
key_types += [KeyType(KeyType.RSA, p["KeyPair"]["KeySize"]["Value"])]
else:
key_types += [KeyType(KeyType.RSA, x) for x in KeyType.ALLOWED_SIZES]
if p["KeyPair"].get("EllipticCurve", {}).get("Locked"):
key_types += [KeyType(KeyType.ECDSA, p["KeyPair"]["EllipticCurve"]["Value"])]
else:
key_types += [KeyType(KeyType.ECDSA, x) for x in KeyType.ALLOWED_CURVES]
return Policy(key_types=key_types)

@staticmethod
def _parse_zone_data_to_object(data):
s = data["Policy"]["Subject"]
ou = s['OrganizationalUnit'].get('Values')
policy = TPPConnection._parse_zone_config_to_policy(data)
if data["Policy"]["KeyPair"]["KeyAlgorithm"]["Value"] == "RSA":
key_type = KeyType(KeyType.RSA, data["Policy"]["KeyPair"]["KeySize"]["Value"])
elif data["Policy"]["KeyPair"]["KeyAlgorithm"]["Value"] == "ECC":
key_type = KeyType(KeyType.ECDSA, data["Policy"]["KeyPair"]["EllipticCurve"]["Value"])
else:
key_type = None
z = ZoneConfig(
organization=CertField(s['Organization']['Value'], locked=s['Organization']['Locked']),
organizational_unit=CertField(ou, locked=s['OrganizationalUnit']['Locked']),
country=CertField(s['Country']['Value'], locked=s['Country']['Locked']),
province=CertField(s['State']['Value'], locked=s['State']['Locked']),
locality=CertField(s['City']['Value'], locked=s['City']['Locked']),
policy=policy,
key_type=key_type,
)
return z

def read_zone_conf(self, tag):
status, data = self._post(URLS.ZONE_CONFIG, {"PolicyDN": self._get_policy_dn(tag)})
if status != HTTPStatus.OK:
raise ServerUnexptedBehavior("Server returns %d status on reading zone configuration." % status)
return self._parse_zone_data_to_object(data)

def import_cert(self, request):
raise NotImplementedError

@staticmethod
def _get_policy_dn(zone):
if zone is None:
log.error("Bad zone: %s" % zone)
raise ClientBadData
if re.match(r"^\\\\VED\\\\Policy", zone):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

regex match API would be lot better here and will remove lot of duplicate code below w.r.t regex match

return zone
else:
if re.match(r"^\\\\", zone):
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a tidbit of input: Python string objects have a .startswith() method that's easier to read than a regex, although a regex works fine.

return r"\\VED\\Policy" + zone
else:
return r"\\VED\\Policy\\" + zone

def search_by_thumbprint(self, thumbprint):
"""
:param str thumbprint:
"""
thumbprint = re.sub(r'[^\dabcdefABCDEF]', "", thumbprint)
thumbprint = thumbprint.upper()
status, data = self._get(URLS.CERTIFICATE_SEARCH, params={"Thumbprint": thumbprint})
if status != HTTPStatus.OK:
raise ServerUnexptedBehavior

if not data['Certificates']:
raise ClientBadData("Certificate not found by thumbprint")
return data['Certificates'][0]['DN']

def _read_config_dn(self, dn, attribute_name):
status, data = self._post(self.urls[URLS.CONFIG_READ_DN], {
status, data = self._post(URLS.CONFIG_READ_DN, {
"ObjectDN": dn,
"AttributeName": attribute_name,
})
Expand Down
Loading