diff --git a/tests/test_e2e.py b/tests/test_e2e.py index dc3cafd..e30ba63 100644 --- a/tests/test_e2e.py +++ b/tests/test_e2e.py @@ -18,8 +18,6 @@ from __future__ import absolute_import, division, generators, unicode_literals, print_function, nested_scopes, \ with_statement -from tests.assets import TEST_KEY_RSA_4096, TEST_KEY_ECDSA, TEST_KEY_RSA_2048_ENCRYPTED, POLICY_CLOUD1, POLICY_TPP1, \ - EXAMPLE_CSR, EXAMPLE_CHAIN from vcert import CloudConnection, CertificateRequest, TPPConnection, FakeConnection, ZoneConfig, RevocationRequest, \ TPPTokenConnection from vcert.common import CertField, KeyType @@ -40,7 +38,8 @@ from cryptography.hazmat.primitives import serialization, hashes logging.basicConfig(level=logging.DEBUG) logging.getLogger("urllib3").setLevel(logging.DEBUG) - +from assets import TEST_KEY_RSA_4096, TEST_KEY_ECDSA, TEST_KEY_RSA_2048_ENCRYPTED, POLICY_CLOUD1, POLICY_TPP1,\ + EXAMPLE_CSR, EXAMPLE_CHAIN FAKE = environ.get('FAKE') TOKEN = environ.get('CLOUD_APIKEY') diff --git a/vcert/connection_tpp.py b/vcert/connection_tpp.py index ff4f663..9003dd7 100644 --- a/vcert/connection_tpp.py +++ b/vcert/connection_tpp.py @@ -17,22 +17,42 @@ from __future__ import (absolute_import, division, generators, unicode_literals, print_function, nested_scopes, with_statement) +import base64 import logging as log import re import time import requests +from cryptography.hazmat.backends import default_backend +from cryptography import x509 +from cryptography.x509 import SignatureAlgorithmOID as algos -from .common import MIME_JSON -from .connection_tpp_common import TPPCommonConnection, URLS -from .errors import (ClientBadData, AuthenticationError, ServerUnexptedBehavior) +from .common import CommonConnection, MIME_JSON, CertField, ZoneConfig, Policy, KeyType +from .pem import parse_pem +from .errors import (ServerUnexptedBehavior, ClientBadData, CertificateRequestError, AuthenticationError, + CertificateRenewError) from .http import HTTPStatus +class URLS: + API_BASE_URL = "" + + AUTHORIZE = "authorize/" + CERTIFICATE_REQUESTS = "certificates/request" + CERTIFICATE_RETRIEVE = "certificates/retrieve" + FIND_POLICY = "config/findpolicy" + CERTIFICATE_REVOKE = "certificates/revoke" + CERTIFICATE_RENEW = "certificates/renew" + CERTIFICATE_SEARCH = "certificates/" + CERTIFICATE_IMPORT = "certificates/import" + ZONE_CONFIG = "certificates/checkpolicy" + CONFIG_READ_DN = "Config/ReadDn" + + TOKEN_HEADER_NAME = "x-venafi-api-key" # nosec -class TPPConnection(TPPCommonConnection): +class TPPConnection(CommonConnection): def __init__(self, user, password, url, http_request_kwargs=None): """ :param str user: @@ -40,23 +60,23 @@ def __init__(self, user, password, url, http_request_kwargs=None): :param str url: :param dict[str,Any] http_request_kwargs: """ - super().__init__(http_request_kwargs=http_request_kwargs) self._base_url = url # type: str self._user = user # type: str self._password = password # type: str self._token = None # type: tuple + if http_request_kwargs is None: + http_request_kwargs = {"timeout": 180} + elif "timeout" not in http_request_kwargs: + http_request_kwargs["timeout"] = 180 + self._http_request_kwargs = http_request_kwargs or {} - def _create_url_dictionary(self): - self.urls = dict() - self.urls[URLS.CERTIFICATE_REQUESTS] = URLS.CERTIFICATE_REQUESTS - self.urls[URLS.CERTIFICATE_RETRIEVE] = URLS.CERTIFICATE_RETRIEVE - self.urls[URLS.FIND_POLICY] = URLS.FIND_POLICY - self.urls[URLS.CERTIFICATE_REVOKE] = URLS.CERTIFICATE_REVOKE - self.urls[URLS.CERTIFICATE_RENEW] = URLS.CERTIFICATE_RENEW - self.urls[URLS.CERTIFICATE_SEARCH] = URLS.CERTIFICATE_SEARCH - self.urls[URLS.CERTIFICATE_IMPORT] = URLS.CERTIFICATE_IMPORT - self.urls[URLS.ZONE_CONFIG] = URLS.ZONE_CONFIG - self.urls[URLS.CONFIG_READ_DN] = URLS.CONFIG_READ_DN + def __setattr__(self, key, value): + if key == "_base_url": + value = self._normalize_and_verify_base_url(value) + self.__dict__[key] = value + + def __str__(self): + return "[TPP] %s" % self._base_url def _get(self, url="", params=None): if not self._token or self._token[1] < time.time() + 1: @@ -109,11 +129,220 @@ def auth(self): log.error("Authentication status is not %s but %s. Exiting" % (HTTPStatus.OK, status[0])) raise AuthenticationError + # TODO: Need to add service generated CSR implementation + def request_cert(self, request, zone): + if not request.csr: + request.build_csr() + request_data = {"PolicyDN": self._get_policy_dn(zone), + "PKCS10": request.csr, + "ObjectName": request.friendly_name, + "DisableAutomaticRenewal": "true"} + if request.origin: + request_data["Origin"] = request.origin + ca_origin = {"Name": "Origin", "Value": request.origin} + if request_data.get("CASpecificAttributes"): + request_data["CASpecificAttributes"].append(ca_origin) + else: + request_data["CASpecificAttributes"] = [ca_origin] + status, data = self._post(URLS.CERTIFICATE_REQUESTS, data=request_data) + if status == HTTPStatus.OK: + request.id = data['CertificateDN'] + log.debug("Certificate sucessfully requested with request id %s." % request.id) + return True + + log.error("Request status is not %s. %s." % HTTPStatus.OK, status) + raise CertificateRequestError + + def retrieve_cert(self, certificate_request): + log.debug("Getting certificate status for id %s" % certificate_request.id) + + retrive_request = dict(CertificateDN=certificate_request.id, Format="base64", IncludeChain='true') + + if certificate_request.chain_option == "last": + retrive_request['RootFirstOrder'] = 'false' + retrive_request['IncludeChain'] = 'true' + elif certificate_request.chain_option == "first": + retrive_request['RootFirstOrder'] = 'true' + retrive_request['IncludeChain'] = 'true' + elif certificate_request.chain_option == "ignore": + retrive_request['IncludeChain'] = 'false' + else: + log.error("chain option %s is not valid" % certificate_request.chain_option) + raise ClientBadData + + status, data = self._post(URLS.CERTIFICATE_RETRIEVE, data=retrive_request) + if status == HTTPStatus.OK: + pem64 = data['CertificateData'] + pem = base64.b64decode(pem64) + return parse_pem(pem.decode(), certificate_request.chain_option) + elif status == HTTPStatus.ACCEPTED: + log.debug(data['Status']) + return None + + log.error("Status is not %s. %s" % HTTPStatus.OK, status) + raise ServerUnexptedBehavior + + def revoke_cert(self, request): + if not (request.id or request.thumbprint): + raise ClientBadData + d = { + "Disable": request.disable + } + if request.reason: + d["Reason"] = request.reason + if request.id: + d["CertificateDN"] = request.id + elif request.thumbprint: + d["Thumbprint"] = request.thumbprint + else: + raise ClientBadData + if request.comments: + d["Comments"] = request.comments + status, data = self._post(URLS.CERTIFICATE_REVOKE, data=d) + if status in (HTTPStatus.OK, HTTPStatus.ACCEPTED): + return data + + raise ServerUnexptedBehavior + + def renew_cert(self, request, reuse_key=False): + if not request.id and not request.thumbprint: + log.debug("Request id or thumbprint must be specified for TPP") + raise CertificateRenewError + if not request.id and request.thumbprint: + request.id = self.search_by_thumbprint(request.thumbprint) + if reuse_key: + log.debug("Trying to renew certificate %s" % request.id) + status, data = self._post(URLS.CERTIFICATE_RENEW, data={"CertificateDN": request.id}) + if not data['Success']: + raise CertificateRenewError + return + cert = self.retrieve_cert(request) + cert = x509.load_pem_x509_certificate(cert.cert.encode(), default_backend()) + for a in cert.subject: + if a.oid == x509.NameOID.COMMON_NAME: + request.common_name = a.value + elif a.oid == x509.NameOID.COUNTRY_NAME: + request.country = a.value + elif a.oid == x509.NameOID.LOCALITY_NAME: + request.locality = a.value + elif a.oid == x509.NameOID.STATE_OR_PROVINCE_NAME: + request.province = a.value + elif a.oid == x509.NameOID.ORGANIZATION_NAME: + request.organization = a.value + elif a.oid == x509.NameOID.ORGANIZATIONAL_UNIT_NAME: + request.organizational_unit = a.value + for e in cert.extensions: + if e.oid == x509.OID_SUBJECT_ALTERNATIVE_NAME: + request.san_dns = list([x.value for x in e.value if isinstance(x, x509.DNSName)]) + request.email_addresses = list([x.value for x in e.value if isinstance(x, x509.RFC822Name)]) + request.ip_addresses = list([x.value.exploded for x in e.value if isinstance(x, x509.IPAddress)]) + if cert.signature_algorithm_oid in (algos.ECDSA_WITH_SHA1, algos.ECDSA_WITH_SHA224, algos.ECDSA_WITH_SHA256, + algos.ECDSA_WITH_SHA384, algos.ECDSA_WITH_SHA512): + request.key_type = (KeyType.ECDSA, KeyType.ALLOWED_CURVES[0]) + else: + request.key_type = KeyType(KeyType.RSA, 2048) # todo: make parsing key size + if not request.csr: + request.build_csr() + status, data = self._post(URLS.CERTIFICATE_RENEW, + data={"CertificateDN": request.id, "PKCS10": request.csr}) + if status == HTTPStatus.OK: + if "CertificateDN" in data: + request.id = data['CertificateDN'] + log.debug("Certificate successfully requested with request id %s." % request.id) + return True + + log.error("Request status is not %s. %s." % HTTPStatus.OK, status) + raise CertificateRequestError + + @staticmethod + def _parse_zone_config_to_policy(data): + # todo: parse over values to regexps (dont forget tests!) + p = data["Policy"] + if p["KeyPair"]["KeyAlgorithm"]["Locked"]: + if p["KeyPair"]["KeyAlgorithm"]["Value"] == "RSA": + if p["KeyPair"]["KeySize"]["Locked"]: + key_types = [KeyType(KeyType.RSA, p["KeyPair"]["KeySize"]["Value"])] + else: + key_types = [KeyType(KeyType.RSA, x) for x in KeyType.ALLOWED_SIZES] + elif p["KeyPair"]["KeyAlgorithm"]["Value"] == "ECC": + if p["KeyPair"]["EllipticCurve"]["Locked"]: + key_types = [KeyType(KeyType.ECDSA, p["KeyPair"]["EllipticCurve"]["Value"])] + else: + key_types = [KeyType(KeyType.ECDSA, x) for x in KeyType.ALLOWED_CURVES] + else: + raise ServerUnexptedBehavior + else: + key_types = [] + if p["KeyPair"].get("KeySize", {}).get("Locked"): + key_types += [KeyType(KeyType.RSA, p["KeyPair"]["KeySize"]["Value"])] + else: + key_types += [KeyType(KeyType.RSA, x) for x in KeyType.ALLOWED_SIZES] + if p["KeyPair"].get("EllipticCurve", {}).get("Locked"): + key_types += [KeyType(KeyType.ECDSA, p["KeyPair"]["EllipticCurve"]["Value"])] + else: + key_types += [KeyType(KeyType.ECDSA, x) for x in KeyType.ALLOWED_CURVES] + return Policy(key_types=key_types) + + @staticmethod + def _parse_zone_data_to_object(data): + s = data["Policy"]["Subject"] + ou = s['OrganizationalUnit'].get('Values') + policy = TPPConnection._parse_zone_config_to_policy(data) + if data["Policy"]["KeyPair"]["KeyAlgorithm"]["Value"] == "RSA": + key_type = KeyType(KeyType.RSA, data["Policy"]["KeyPair"]["KeySize"]["Value"]) + elif data["Policy"]["KeyPair"]["KeyAlgorithm"]["Value"] == "ECC": + key_type = KeyType(KeyType.ECDSA, data["Policy"]["KeyPair"]["EllipticCurve"]["Value"]) + else: + key_type = None + z = ZoneConfig( + organization=CertField(s['Organization']['Value'], locked=s['Organization']['Locked']), + organizational_unit=CertField(ou, locked=s['OrganizationalUnit']['Locked']), + country=CertField(s['Country']['Value'], locked=s['Country']['Locked']), + province=CertField(s['State']['Value'], locked=s['State']['Locked']), + locality=CertField(s['City']['Value'], locked=s['City']['Locked']), + policy=policy, + key_type=key_type, + ) + return z + + def read_zone_conf(self, tag): + status, data = self._post(URLS.ZONE_CONFIG, {"PolicyDN": self._get_policy_dn(tag)}) + if status != HTTPStatus.OK: + raise ServerUnexptedBehavior("Server returns %d status on reading zone configuration." % status) + return self._parse_zone_data_to_object(data) + def import_cert(self, request): raise NotImplementedError + @staticmethod + def _get_policy_dn(zone): + if zone is None: + log.error("Bad zone: %s" % zone) + raise ClientBadData + if re.match(r"^\\\\VED\\\\Policy", zone): + return zone + else: + if re.match(r"^\\\\", zone): + return r"\\VED\\Policy" + zone + else: + return r"\\VED\\Policy\\" + zone + + def search_by_thumbprint(self, thumbprint): + """ + :param str thumbprint: + """ + thumbprint = re.sub(r'[^\dabcdefABCDEF]', "", thumbprint) + thumbprint = thumbprint.upper() + status, data = self._get(URLS.CERTIFICATE_SEARCH, params={"Thumbprint": thumbprint}) + if status != HTTPStatus.OK: + raise ServerUnexptedBehavior + + if not data['Certificates']: + raise ClientBadData("Certificate not found by thumbprint") + return data['Certificates'][0]['DN'] + def _read_config_dn(self, dn, attribute_name): - status, data = self._post(self.urls[URLS.CONFIG_READ_DN], { + status, data = self._post(URLS.CONFIG_READ_DN, { "ObjectDN": dn, "AttributeName": attribute_name, }) diff --git a/vcert/connection_tpp_common.py b/vcert/connection_tpp_common.py deleted file mode 100644 index 46f449c..0000000 --- a/vcert/connection_tpp_common.py +++ /dev/null @@ -1,299 +0,0 @@ -# -# Copyright 2020 Venafi, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - - -from __future__ import (absolute_import, division, generators, unicode_literals, print_function, nested_scopes, - with_statement) - -import base64 -import logging as log -import re -from abc import ABC, abstractmethod - -from cryptography.hazmat.backends import default_backend -from cryptography import x509 -from cryptography.x509 import SignatureAlgorithmOID as algos - -from .common import CommonConnection, CertField, ZoneConfig, Policy, KeyType -from .pem import parse_pem -from .errors import (ServerUnexptedBehavior, ClientBadData, CertificateRequestError, CertificateRenewError) -from .http import HTTPStatus - - -class URLS: - AUTHORIZE = "authorize/" - CERTIFICATE_REQUESTS = "certificates/request" - CERTIFICATE_RETRIEVE = "certificates/retrieve" - FIND_POLICY = "config/findpolicy" - CERTIFICATE_REVOKE = "certificates/revoke" - CERTIFICATE_RENEW = "certificates/renew" - CERTIFICATE_SEARCH = "certificates/" - CERTIFICATE_IMPORT = "certificates/import" - ZONE_CONFIG = "certificates/checkpolicy" - CONFIG_READ_DN = "Config/ReadDn" - - -class TPPCommonConnection(CommonConnection, ABC): - def __init__(self, http_request_kwargs=None): - """ - :param str url: - :param dict[str,Any] http_request_kwargs: - """ - if http_request_kwargs is None: - http_request_kwargs = {"timeout": 180} - elif "timeout" not in http_request_kwargs: - http_request_kwargs["timeout"] = 180 - self._http_request_kwargs = http_request_kwargs or {} - self.urls = dict() - self._create_url_dictionary() - - def __setattr__(self, key, value): - if key == "_base_url": - value = self._normalize_and_verify_base_url(value) - self.__dict__[key] = value - - def __str__(self): - return "[TPP] %s" % self._base_url - - @abstractmethod - def _create_url_dictionary(self): - raise NotImplementedError - - @staticmethod - @abstractmethod - def _normalize_and_verify_base_url(u): - raise NotImplementedError - - @abstractmethod - def _get(self, url=None, params=None): - raise NotImplementedError - - @abstractmethod - def _post(self, url=None, data=None): - raise NotImplementedError - - # TODO: Need to add service generated CSR implementation - def request_cert(self, request, zone): - if not request.csr: - request.build_csr() - request_data = {"PolicyDN": self._get_policy_dn(zone), - "PKCS10": request.csr, - "ObjectName": request.friendly_name, - "DisableAutomaticRenewal": "true"} - if request.origin: - request_data["Origin"] = request.origin - ca_origin = {"Name": "Origin", "Value": request.origin} - if request_data.get("CASpecificAttributes"): - request_data["CASpecificAttributes"].append(ca_origin) - else: - request_data["CASpecificAttributes"] = [ca_origin] - status, data = self._post(self.urls[URLS.CERTIFICATE_REQUESTS], data=request_data) - if status == HTTPStatus.OK: - request.id = data['CertificateDN'] - log.debug("Certificate sucessfully requested with request id %s." % request.id) - return True - else: - log.error("Request status is not %s. %s." % HTTPStatus.OK, status) - raise CertificateRequestError - - def retrieve_cert(self, certificate_request): - log.debug("Getting certificate status for id %s" % certificate_request.id) - - retrieve_request = dict(CertificateDN=certificate_request.id, Format="base64", IncludeChain='true') - - if certificate_request.chain_option == "last": - retrieve_request['RootFirstOrder'] = 'false' - retrieve_request['IncludeChain'] = 'true' - elif certificate_request.chain_option == "first": - retrieve_request['RootFirstOrder'] = 'true' - retrieve_request['IncludeChain'] = 'true' - elif certificate_request.chain_option == "ignore": - retrieve_request['IncludeChain'] = 'false' - else: - log.error("chain option %s is not valid" % certificate_request.chain_option) - raise ClientBadData - - status, data = self._post(self.urls[URLS.CERTIFICATE_RETRIEVE], data=retrieve_request) - if status == HTTPStatus.OK: - pem64 = data['CertificateData'] - pem = base64.b64decode(pem64) - return parse_pem(pem.decode(), certificate_request.chain_option) - elif status == HTTPStatus.ACCEPTED: - log.debug(data['Status']) - return None - - log.error("Status is not %s. %s" % HTTPStatus.OK, status) - raise ServerUnexptedBehavior - - def revoke_cert(self, request): - if not (request.id or request.thumbprint): - raise ClientBadData - d = { - "Disable": request.disable - } - if request.reason: - d["Reason"] = request.reason - if request.id: - d["CertificateDN"] = request.id - elif request.thumbprint: - d["Thumbprint"] = request.thumbprint - else: - raise ClientBadData - if request.comments: - d["Comments"] = request.comments - status, data = self._post(self.urls[URLS.CERTIFICATE_REVOKE], data=d) - if status in (HTTPStatus.OK, HTTPStatus.ACCEPTED): - return data - else: - raise ServerUnexptedBehavior - - def renew_cert(self, request, reuse_key=False): - if not request.id and not request.thumbprint: - log.debug("Request id or thumbprint must be specified for TPP") - raise CertificateRenewError - if not request.id and request.thumbprint: - request.id = self.search_by_thumbprint(request.thumbprint) - if reuse_key: - log.debug("Trying to renew certificate %s" % request.id) - status, data = self._post(self.urls[URLS.CERTIFICATE_RENEW], data={"CertificateDN": request.id}) - if not data['Success']: - raise CertificateRenewError - return - cert = self.retrieve_cert(request) - cert = x509.load_pem_x509_certificate(cert.cert.encode(), default_backend()) - for a in cert.subject: - if a.oid == x509.NameOID.COMMON_NAME: - request.common_name = a.value - elif a.oid == x509.NameOID.COUNTRY_NAME: - request.country = a.value - elif a.oid == x509.NameOID.LOCALITY_NAME: - request.locality = a.value - elif a.oid == x509.NameOID.STATE_OR_PROVINCE_NAME: - request.province = a.value - elif a.oid == x509.NameOID.ORGANIZATION_NAME: - request.organization = a.value - elif a.oid == x509.NameOID.ORGANIZATIONAL_UNIT_NAME: - request.organizational_unit = a.value - for e in cert.extensions: - if e.oid == x509.OID_SUBJECT_ALTERNATIVE_NAME: - request.san_dns = list([x.value for x in e.value if isinstance(x, x509.DNSName)]) - request.email_addresses = list([x.value for x in e.value if isinstance(x, x509.RFC822Name)]) - request.ip_addresses = list([x.value.exploded for x in e.value if isinstance(x, x509.IPAddress)]) - if cert.signature_algorithm_oid in (algos.ECDSA_WITH_SHA1, algos.ECDSA_WITH_SHA224, algos.ECDSA_WITH_SHA256, - algos.ECDSA_WITH_SHA384, algos.ECDSA_WITH_SHA512): - request.key_type = (KeyType.ECDSA, KeyType.ALLOWED_CURVES[0]) - else: - request.key_type = KeyType(KeyType.RSA, 2048) # todo: make parsing key size - if not request.csr: - request.build_csr() - status, data = self._post(self.urls[URLS.CERTIFICATE_RENEW], - data={"CertificateDN": request.id, "PKCS10": request.csr}) - if status == HTTPStatus.OK: - if "CertificateDN" in data: - request.id = data['CertificateDN'] - log.debug("Certificate successfully requested with request id %s." % request.id) - return True - - log.error("Request status is not %s. %s." % HTTPStatus.OK, status) - raise CertificateRequestError - - @staticmethod - def _parse_zone_config_to_policy(data): - # todo: parse over values to regexps (dont forget tests!) - p = data["Policy"] - if p["KeyPair"]["KeyAlgorithm"]["Locked"]: - if p["KeyPair"]["KeyAlgorithm"]["Value"] == "RSA": - if p["KeyPair"]["KeySize"]["Locked"]: - key_types = [KeyType(KeyType.RSA, p["KeyPair"]["KeySize"]["Value"])] - else: - key_types = [KeyType(KeyType.RSA, x) for x in KeyType.ALLOWED_SIZES] - elif p["KeyPair"]["KeyAlgorithm"]["Value"] == "ECC": - if p["KeyPair"]["EllipticCurve"]["Locked"]: - key_types = [KeyType(KeyType.ECDSA, p["KeyPair"]["EllipticCurve"]["Value"])] - else: - key_types = [KeyType(KeyType.ECDSA, x) for x in KeyType.ALLOWED_CURVES] - else: - raise ServerUnexptedBehavior - else: - key_types = [] - if p["KeyPair"].get("KeySize", {}).get("Locked"): - key_types += [KeyType(KeyType.RSA, p["KeyPair"]["KeySize"]["Value"])] - else: - key_types += [KeyType(KeyType.RSA, x) for x in KeyType.ALLOWED_SIZES] - if p["KeyPair"].get("EllipticCurve", {}).get("Locked"): - key_types += [KeyType(KeyType.ECDSA, p["KeyPair"]["EllipticCurve"]["Value"])] - else: - key_types += [KeyType(KeyType.ECDSA, x) for x in KeyType.ALLOWED_CURVES] - return Policy(key_types=key_types) - - @staticmethod - def _parse_zone_data_to_object(data): - s = data["Policy"]["Subject"] - ou = s['OrganizationalUnit'].get('Values') - policy = TPPCommonConnection._parse_zone_config_to_policy(data) - if data["Policy"]["KeyPair"]["KeyAlgorithm"]["Value"] == "RSA": - key_type = KeyType(KeyType.RSA, data["Policy"]["KeyPair"]["KeySize"]["Value"]) - elif data["Policy"]["KeyPair"]["KeyAlgorithm"]["Value"] == "ECC": - key_type = KeyType(KeyType.ECDSA, data["Policy"]["KeyPair"]["EllipticCurve"]["Value"]) - else: - key_type = None - z = ZoneConfig( - organization=CertField(s['Organization']['Value'], locked=s['Organization']['Locked']), - organizational_unit=CertField(ou, locked=s['OrganizationalUnit']['Locked']), - country=CertField(s['Country']['Value'], locked=s['Country']['Locked']), - province=CertField(s['State']['Value'], locked=s['State']['Locked']), - locality=CertField(s['City']['Value'], locked=s['City']['Locked']), - policy=policy, - key_type=key_type, - ) - return z - - def read_zone_conf(self, tag): - status, data = self._post(self.urls[URLS.ZONE_CONFIG], {"PolicyDN": self._get_policy_dn(tag)}) - if status != HTTPStatus.OK: - raise ServerUnexptedBehavior("Server returns %d status on reading zone configuration." % status) - return self._parse_zone_data_to_object(data) - - @staticmethod - def _get_policy_dn(zone): - if zone is None: - log.error("Bad zone: %s" % zone) - raise ClientBadData - if re.match(r"^\\\\VED\\\\Policy", zone): - return zone - else: - if re.match(r"^\\\\", zone): - return r"\\VED\\Policy" + zone - else: - return r"\\VED\\Policy\\" + zone - - def search_by_thumbprint(self, thumbprint): - """ - :param str thumbprint: - """ - thumbprint = re.sub(r'[^\dabcdefABCDEF]', "", thumbprint) - thumbprint = thumbprint.upper() - status, data = self._get(self.urls[URLS.CERTIFICATE_SEARCH], params={"Thumbprint": thumbprint}) - if status != HTTPStatus.OK: - raise ServerUnexptedBehavior - - if not data['Certificates']: - raise ClientBadData("Certificate not found by thumbprint") - return data['Certificates'][0]['DN'] - - @abstractmethod - def _read_config_dn(self, dn, attribute_name): - raise NotImplementedError diff --git a/vcert/connection_tpp_token.py b/vcert/connection_tpp_token.py index defafd4..74787e9 100644 --- a/vcert/connection_tpp_token.py +++ b/vcert/connection_tpp_token.py @@ -18,77 +18,107 @@ from __future__ import (absolute_import, division, generators, unicode_literals, print_function, nested_scopes, with_statement) +import base64 import logging as log import re import time -from http import HTTPStatus -import requests - -from .common import MIME_JSON, TokenInfo, Authentication -from .connection_tpp_common import TPPCommonConnection, URLS -from .errors import (ClientBadData, ServerUnexptedBehavior, AuthenticationError) +from cryptography.hazmat.backends import default_backend +from cryptography import x509 +from cryptography.x509 import SignatureAlgorithmOID as algos +from .http import HTTPStatus -API_TOKEN_URL = "vedauth/" # type: str -API_BASE_URL = "vedsdk/" # type: str +import requests -PATH_AUTHORIZE_TOKEN = API_TOKEN_URL + "authorize/oauth" # type: str -PATH_REFRESH_TOKEN = API_TOKEN_URL + "authorize/token" # type: str -PATH_REVOKE_TOKEN = API_TOKEN_URL + "revoke/token" # type: str -PATH_CERTIFICATE_GUID = API_BASE_URL + "certificate" # type: str +from .common import MIME_JSON, TokenInfo, Authentication, CommonConnection, KeyType, Policy, ZoneConfig, CertField +from .errors import (ClientBadData, ServerUnexptedBehavior, AuthenticationError, CertificateRequestError, + CertificateRenewError) +from .pem import parse_pem HEADER_AUTHORIZATION = "Authorization" # type: str -KEY_ACCESS_TOKEN = "access_token" # type: str -KEY_REFRESH_TOKEN = "refresh_token" # type: str +KEY_ACCESS_TOKEN = "access_token" # type: str # nosec +KEY_REFRESH_TOKEN = "refresh_token" # type: str # nosec KEY_EXPIRATION_DATE = "expiration_date" # type: str -class TPPTokenConnection(TPPCommonConnection): +class URLS: + API_TOKEN_URL = "vedauth/" # type: str # nosec + API_BASE_URL = "vedsdk/" # type: str # nosec + + AUTHORIZE_TOKEN = API_TOKEN_URL + "authorize/oauth" # type: str + REFRESH_TOKEN = API_TOKEN_URL + "authorize/token" # type: str + REVOKE_TOKEN = API_TOKEN_URL + "revoke/token" # type: str + + AUTHORIZE = API_BASE_URL + "authorize/" + CERTIFICATE_REQUESTS = API_BASE_URL + "certificates/request" + CERTIFICATE_RETRIEVE = API_BASE_URL + "certificates/retrieve" + FIND_POLICY = API_BASE_URL + "config/findpolicy" + CERTIFICATE_REVOKE = API_BASE_URL + "certificates/revoke" + CERTIFICATE_RENEW = API_BASE_URL + "certificates/renew" + CERTIFICATE_SEARCH = API_BASE_URL + "certificates/" + CERTIFICATE_IMPORT = API_BASE_URL + "certificates/import" + ZONE_CONFIG = API_BASE_URL + "certificates/checkpolicy" + CONFIG_READ_DN = API_BASE_URL + "Config/ReadDn" + + +class TPPTokenConnection(CommonConnection): def __init__(self, url, user=None, password=None, access_token=None, refresh_token=None, http_request_kwargs=None): - super().__init__(http_request_kwargs=http_request_kwargs) + """ + :param str url: + :param str user: + :param str password: + :param str access_token: + :param str refresh_token: + :param dict[str,Any] http_request_kwargs: + """ self._base_url = url # type: str self._auth = Authentication(user=user, password=password, access_token=access_token, refresh_token=refresh_token) # type: Authentication + if http_request_kwargs is None: + http_request_kwargs = {"timeout": 180} + elif "timeout" not in http_request_kwargs: + http_request_kwargs["timeout"] = 180 + self._http_request_kwargs = http_request_kwargs or {} + + def __setattr__(self, key, value): + if key == "_base_url": + value = self._normalize_and_verify_base_url(value) + self.__dict__[key] = value + + def __str__(self): + return "[TPP] %s" % self._base_url + + def _get(self, url=None, params=None, check_token=True, include_headers=True): + if check_token: + self._check_token() + + headers = {} + if include_headers: + token = self._get_auth_header_value(self._auth.access_token) + headers = {HEADER_AUTHORIZATION: token, 'content-type': MIME_JSON, 'cache-control': 'no-cache'} + + r = requests.get(self._base_url + url, headers=headers, params=params, **self._http_request_kwargs) + return self.process_server_response(r) - def _create_url_dictionary(self): - self.urls = dict() - self.urls[PATH_AUTHORIZE_TOKEN] = PATH_AUTHORIZE_TOKEN - self.urls[PATH_REFRESH_TOKEN] = PATH_REFRESH_TOKEN - self.urls[PATH_REVOKE_TOKEN] = PATH_REVOKE_TOKEN - - self.urls[URLS.CERTIFICATE_REQUESTS] = API_BASE_URL + URLS.CERTIFICATE_REQUESTS - self.urls[URLS.CERTIFICATE_RETRIEVE] = API_BASE_URL + URLS.CERTIFICATE_RETRIEVE - self.urls[URLS.FIND_POLICY] = API_BASE_URL + URLS.FIND_POLICY - self.urls[URLS.CERTIFICATE_REVOKE] = API_BASE_URL + URLS.CERTIFICATE_REVOKE - self.urls[URLS.CERTIFICATE_RENEW] = API_BASE_URL + URLS.CERTIFICATE_RENEW - self.urls[URLS.CERTIFICATE_SEARCH] = API_BASE_URL + URLS.CERTIFICATE_SEARCH - self.urls[URLS.CERTIFICATE_IMPORT] = API_BASE_URL + URLS.CERTIFICATE_IMPORT - self.urls[URLS.ZONE_CONFIG] = API_BASE_URL + URLS.ZONE_CONFIG - self.urls[URLS.CONFIG_READ_DN] = API_BASE_URL + URLS.CONFIG_READ_DN - - def _get(self, url=None, params=None): - # There is no token - if not self._auth.access_token: - self.get_access_token() - log.debug("Token is %s, expire date is %s" % (self._auth.access_token, self._auth.token_expires)) - - # Token expired, get new token - elif self._auth.token_expires and self._auth.token_expires < time.time(): - if self._auth.refresh_token: - self.refresh_access_token() - log.debug("Token is %s, expire date is %s" % (self._auth.access_token, self._auth.token_expires)) - else: - raise AuthenticationError("Access Token expired. No refresh token provided.") + def _post(self, url=None, data=None, check_token=True, include_headers=True): + if check_token: + self._check_token() - token = TPPTokenConnection._get_auth_header_value(self._auth.access_token) - r = requests.get(self._base_url + url, headers={HEADER_AUTHORIZATION: token, 'content-type': MIME_JSON, - 'cache-control': 'no-cache'}, params=params, **self._http_request_kwargs) + headers = {} + if include_headers: + token = self._get_auth_header_value(self._auth.access_token) + headers = {HEADER_AUTHORIZATION: token, 'content-type': MIME_JSON, "cache-control": "no-cache"} + if isinstance(data, dict): + r = requests.post(self._base_url + url, headers=headers, json=data, **self._http_request_kwargs) + else: + log.error("Unexpected client data type: %s for %s" % (type(data), url)) + raise ClientBadData return self.process_server_response(r) - def _post(self, url=None, data=None): + def _check_token(self): if not self._auth.access_token: self.get_access_token() log.debug("Token is %s, expire date is %s" % (self._auth.access_token, self._auth.token_expires)) @@ -101,15 +131,6 @@ def _post(self, url=None, data=None): else: raise AuthenticationError("Access Token expired. No refresh token provided.") - if isinstance(data, dict): - token = TPPTokenConnection._get_auth_header_value(self._auth.access_token) - r = requests.post(self._base_url + url, headers={HEADER_AUTHORIZATION: token, 'content-type': MIME_JSON, - "cache-control": "no-cache"}, json=data, **self._http_request_kwargs) - else: - log.error("Unexpected client data type: %s for %s" % (type(data), url)) - raise ClientBadData - return self.process_server_response(r) - @staticmethod def _normalize_and_verify_base_url(u): if u.startswith("http://"): @@ -128,6 +149,224 @@ def auth(self): def import_cert(self, request): raise NotImplementedError + # TODO: Need to add service generated CSR implementation + def request_cert(self, request, zone): + if not request.csr: + request.build_csr() + request_data = {"PolicyDN": self._get_policy_dn(zone), + "PKCS10": request.csr, + "ObjectName": request.friendly_name, + "DisableAutomaticRenewal": "true"} + if request.origin: + request_data["Origin"] = request.origin + ca_origin = {"Name": "Origin", "Value": request.origin} + if request_data.get("CASpecificAttributes"): + request_data["CASpecificAttributes"].append(ca_origin) + else: + request_data["CASpecificAttributes"] = [ca_origin] + status, data = self._post(URLS.CERTIFICATE_REQUESTS, data=request_data) + if status == HTTPStatus.OK: + request.id = data['CertificateDN'] + log.debug("Certificate sucessfully requested with request id %s." % request.id) + return True + + log.error("Request status is not %s. %s." % HTTPStatus.OK, status) + raise CertificateRequestError + + def retrieve_cert(self, certificate_request): + log.debug("Getting certificate status for id %s" % certificate_request.id) + + retrive_request = dict(CertificateDN=certificate_request.id, Format="base64", IncludeChain='true') + + if certificate_request.chain_option == "last": + retrive_request['RootFirstOrder'] = 'false' + retrive_request['IncludeChain'] = 'true' + elif certificate_request.chain_option == "first": + retrive_request['RootFirstOrder'] = 'true' + retrive_request['IncludeChain'] = 'true' + elif certificate_request.chain_option == "ignore": + retrive_request['IncludeChain'] = 'false' + else: + log.error("chain option %s is not valid" % certificate_request.chain_option) + raise ClientBadData + + status, data = self._post(URLS.CERTIFICATE_RETRIEVE, data=retrive_request) + if status == HTTPStatus.OK: + pem64 = data['CertificateData'] + pem = base64.b64decode(pem64) + return parse_pem(pem.decode(), certificate_request.chain_option) + elif status == HTTPStatus.ACCEPTED: + log.debug(data['Status']) + return None + + log.error("Status is not %s. %s" % HTTPStatus.OK, status) + raise ServerUnexptedBehavior + + def revoke_cert(self, request): + if not (request.id or request.thumbprint): + raise ClientBadData + d = { + "Disable": request.disable + } + if request.reason: + d["Reason"] = request.reason + if request.id: + d["CertificateDN"] = request.id + elif request.thumbprint: + d["Thumbprint"] = request.thumbprint + else: + raise ClientBadData + if request.comments: + d["Comments"] = request.comments + status, data = self._post(URLS.CERTIFICATE_REVOKE, data=d) + if status in (HTTPStatus.OK, HTTPStatus.ACCEPTED): + return data + else: + raise ServerUnexptedBehavior + + def renew_cert(self, request, reuse_key=False): + if not request.id and not request.thumbprint: + log.debug("Request id or thumbprint must be specified for TPP") + raise CertificateRenewError + if not request.id and request.thumbprint: + request.id = self.search_by_thumbprint(request.thumbprint) + if reuse_key: + log.debug("Trying to renew certificate %s" % request.id) + status, data = self._post(URLS.CERTIFICATE_RENEW, data={"CertificateDN": request.id}) + if not data['Success']: + raise CertificateRenewError + return + cert = self.retrieve_cert(request) + cert = x509.load_pem_x509_certificate(cert.cert.encode(), default_backend()) + for a in cert.subject: + if a.oid == x509.NameOID.COMMON_NAME: + request.common_name = a.value + elif a.oid == x509.NameOID.COUNTRY_NAME: + request.country = a.value + elif a.oid == x509.NameOID.LOCALITY_NAME: + request.locality = a.value + elif a.oid == x509.NameOID.STATE_OR_PROVINCE_NAME: + request.province = a.value + elif a.oid == x509.NameOID.ORGANIZATION_NAME: + request.organization = a.value + elif a.oid == x509.NameOID.ORGANIZATIONAL_UNIT_NAME: + request.organizational_unit = a.value + for e in cert.extensions: + if e.oid == x509.OID_SUBJECT_ALTERNATIVE_NAME: + request.san_dns = list([x.value for x in e.value if isinstance(x, x509.DNSName)]) + request.email_addresses = list([x.value for x in e.value if isinstance(x, x509.RFC822Name)]) + request.ip_addresses = list([x.value.exploded for x in e.value if isinstance(x, x509.IPAddress)]) + if cert.signature_algorithm_oid in (algos.ECDSA_WITH_SHA1, algos.ECDSA_WITH_SHA224, algos.ECDSA_WITH_SHA256, + algos.ECDSA_WITH_SHA384, algos.ECDSA_WITH_SHA512): + request.key_type = (KeyType.ECDSA, KeyType.ALLOWED_CURVES[0]) + else: + request.key_type = KeyType(KeyType.RSA, 2048) # todo: make parsing key size + if not request.csr: + request.build_csr() + status, data = self._post(URLS.CERTIFICATE_RENEW, + data={"CertificateDN": request.id, "PKCS10": request.csr}) + if status == HTTPStatus.OK: + if "CertificateDN" in data: + request.id = data['CertificateDN'] + log.debug("Certificate successfully requested with request id %s." % request.id) + return True + + log.error("Request status is not %s. %s." % HTTPStatus.OK, status) + raise CertificateRequestError + + @staticmethod + def _parse_zone_config_to_policy(data): + # todo: parse over values to regexps (dont forget tests!) + p = data["Policy"] + if p["KeyPair"]["KeyAlgorithm"]["Locked"]: + if p["KeyPair"]["KeyAlgorithm"]["Value"] == "RSA": + if p["KeyPair"]["KeySize"]["Locked"]: + key_types = [KeyType(KeyType.RSA, p["KeyPair"]["KeySize"]["Value"])] + else: + key_types = [KeyType(KeyType.RSA, x) for x in KeyType.ALLOWED_SIZES] + elif p["KeyPair"]["KeyAlgorithm"]["Value"] == "ECC": + if p["KeyPair"]["EllipticCurve"]["Locked"]: + key_types = [KeyType(KeyType.ECDSA, p["KeyPair"]["EllipticCurve"]["Value"])] + else: + key_types = [KeyType(KeyType.ECDSA, x) for x in KeyType.ALLOWED_CURVES] + else: + raise ServerUnexptedBehavior + else: + key_types = [] + if p["KeyPair"].get("KeySize", {}).get("Locked"): + key_types += [KeyType(KeyType.RSA, p["KeyPair"]["KeySize"]["Value"])] + else: + key_types += [KeyType(KeyType.RSA, x) for x in KeyType.ALLOWED_SIZES] + if p["KeyPair"].get("EllipticCurve", {}).get("Locked"): + key_types += [KeyType(KeyType.ECDSA, p["KeyPair"]["EllipticCurve"]["Value"])] + else: + key_types += [KeyType(KeyType.ECDSA, x) for x in KeyType.ALLOWED_CURVES] + return Policy(key_types=key_types) + + @staticmethod + def _parse_zone_data_to_object(data): + s = data["Policy"]["Subject"] + ou = s['OrganizationalUnit'].get('Values') + policy = TPPTokenConnection._parse_zone_config_to_policy(data) + if data["Policy"]["KeyPair"]["KeyAlgorithm"]["Value"] == "RSA": + key_type = KeyType(KeyType.RSA, data["Policy"]["KeyPair"]["KeySize"]["Value"]) + elif data["Policy"]["KeyPair"]["KeyAlgorithm"]["Value"] == "ECC": + key_type = KeyType(KeyType.ECDSA, data["Policy"]["KeyPair"]["EllipticCurve"]["Value"]) + else: + key_type = None + z = ZoneConfig( + organization=CertField(s['Organization']['Value'], locked=s['Organization']['Locked']), + organizational_unit=CertField(ou, locked=s['OrganizationalUnit']['Locked']), + country=CertField(s['Country']['Value'], locked=s['Country']['Locked']), + province=CertField(s['State']['Value'], locked=s['State']['Locked']), + locality=CertField(s['City']['Value'], locked=s['City']['Locked']), + policy=policy, + key_type=key_type, + ) + return z + + def read_zone_conf(self, tag): + status, data = self._post(URLS.ZONE_CONFIG, {"PolicyDN": self._get_policy_dn(tag)}) + if status != HTTPStatus.OK: + raise ServerUnexptedBehavior("Server returns %d status on reading zone configuration." % status) + return self._parse_zone_data_to_object(data) + + @staticmethod + def _get_policy_dn(zone): + if zone is None: + log.error("Bad zone: %s" % zone) + raise ClientBadData + if re.match(r"^\\\\VED\\\\Policy", zone): + return zone + else: + if re.match(r"^\\\\", zone): + return r"\\VED\\Policy" + zone + else: + return r"\\VED\\Policy\\" + zone + + def search_by_thumbprint(self, thumbprint): + """ + :param str thumbprint: + """ + thumbprint = re.sub(r'[^\dabcdefABCDEF]', "", thumbprint) + thumbprint = thumbprint.upper() + status, data = self._get(URLS.CERTIFICATE_SEARCH, params={"Thumbprint": thumbprint}) + if status != HTTPStatus.OK: + raise ServerUnexptedBehavior + + if not data['Certificates']: + raise ClientBadData("Certificate not found by thumbprint") + return data['Certificates'][0]['DN'] + + def _read_config_dn(self, dn, attribute_name): + status, data = self._post(URLS.CONFIG_READ_DN, { + "ObjectDN": dn, + "AttributeName": attribute_name, + }) + if status != HTTPStatus.OK: + raise ServerUnexptedBehavior("") + return data + def get_access_token(self, authentication=None): """ Obtains an access token to be used for subsequent api operations. @@ -142,12 +381,12 @@ def get_access_token(self, authentication=None): "scope": self._auth.scope, "state": "", } - status, resp_data = self._token_post(self.urls[PATH_AUTHORIZE_TOKEN], request_data) + status, resp_data = self._post(URLS.AUTHORIZE_TOKEN, request_data, False, False) if status != HTTPStatus.OK: raise ServerUnexptedBehavior("Server returns %d status on retrieving access token." % status) token_info = self._parse_access_token_data_to_object(resp_data) - self.update_auth(token_info) + self._update_auth(token_info) return token_info def refresh_access_token(self): @@ -155,30 +394,21 @@ def refresh_access_token(self): "refresh_token": self._auth.refresh_token, "client_id": self._auth.client_id, } - status, resp_data = self._token_post(self.urls[PATH_REFRESH_TOKEN], request_data) + status, resp_data = self._post(URLS.REFRESH_TOKEN, request_data, False, False) if status != HTTPStatus.OK: raise ServerUnexptedBehavior("Server returns %d status on refreshing access token" % status) token_info = self._parse_access_token_data_to_object(resp_data) - self.update_auth(token_info) + self._update_auth(token_info) return token_info def revoke_access_token(self): - status, resp_data = self._get(url=self.urls[PATH_REVOKE_TOKEN]) + status, resp_data = self._get(url=URLS.REVOKE_TOKEN, params=None, check_token=False) if status != HTTPStatus.OK: raise ServerUnexptedBehavior("Server returns %d status on revoking access token" % status) return status, resp_data - def _token_post(self, url, data=None): - if isinstance(data, dict): - tpp_url = self._base_url - response = requests.post(tpp_url + url, json=data, **self._http_request_kwargs) - else: - log.error("Unexpected client data type: %s for %s" % (type(data), url)) - raise ClientBadData - return self.process_server_response(response) - - def update_auth(self, token_info): + def _update_auth(self, token_info): if isinstance(token_info, TokenInfo): self._auth.access_token = token_info.access_token self._auth.refresh_token = token_info.refresh_token @@ -204,6 +434,3 @@ def _parse_access_token_data_to_object(data): token_type=data["token_type"] ) return token_info - - def _read_config_dn(self, dn, attribute_name): - pass