diff --git a/anchor/X509/certificate.py b/anchor/X509/certificate.py index 52a9f01..0866571 100644 --- a/anchor/X509/certificate.py +++ b/anchor/X509/certificate.py @@ -13,127 +13,108 @@ from __future__ import absolute_import -from cryptography.hazmat.backends.openssl import backend +import base64 +import binascii +import io + +from cryptography.hazmat import backends as cio_backends +from cryptography.hazmat.primitives.asymmetric import dsa +from cryptography.hazmat.primitives.asymmetric import padding +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.hazmat.primitives import hashes +from pyasn1.codec.der import decoder +from pyasn1.codec.der import encoder +from pyasn1.type import univ as asn1_univ +from pyasn1_modules import pem +from pyasn1_modules import rfc2459 # X509v3 from anchor.X509 import errors -from anchor.X509 import message_digest +from anchor.X509 import extension from anchor.X509 import name from anchor.X509 import utils +SIGNING_ALGORITHMS = { + ('RSA', 'MD5'): rfc2459.md5WithRSAEncryption, + ('RSA', 'SHA1'): rfc2459.sha1WithRSAEncryption, + ('RSA', 'SHA224'): asn1_univ.ObjectIdentifier('1.2.840.113549.1.1.14'), + ('RSA', 'SHA256'): asn1_univ.ObjectIdentifier('1.2.840.113549.1.1.11'), + ('RSA', 'SHA384'): asn1_univ.ObjectIdentifier('1.2.840.113549.1.1.12'), + ('RSA', 'SHA512'): asn1_univ.ObjectIdentifier('1.2.840.113549.1.1.13'), + ('DSA', 'SHA1'): rfc2459.id_dsa_with_sha1, + ('DSA', 'SHA224'): asn1_univ.ObjectIdentifier('2.16.840.1.101.3.4.3.1'), + ('DSA', 'SHA256'): asn1_univ.ObjectIdentifier('2.16.840.1.101.3.4.3.2'), +} + + class X509CertificateError(errors.X509Error): """Specific error for X509 certificate operations.""" - def __init__(self, what): - super(X509CertificateError, self).__init__(what) - - -class X509Extension(object): - """An X509 V3 Certificate extension.""" - def __init__(self, ext): - self._lib = backend._lib - self._ffi = backend._ffi - self._ext = ext - - def __str__(self): - return "%s %s" % (self.get_name(), self.get_value()) - - def get_name(self): - """Get the extension name as a python string.""" - ext_obj = self._lib.X509_EXTENSION_get_object(self._ext) - ext_nid = self._lib.OBJ_obj2nid(ext_obj) - ext_name_str = self._lib.OBJ_nid2sn(ext_nid) - return self._ffi.string(ext_name_str).decode('ascii') - - def get_value(self): - """Get the extension value as a python string.""" - bio = self._lib.BIO_new(self._lib.BIO_s_mem()) - bio = self._ffi.gc(bio, self._lib.BIO_free) - self._lib.X509V3_EXT_print(bio, self._ext, 0, 0) - size = 1024 - data = self._ffi.new("char[]", size) - self._lib.BIO_gets(bio, data, size) - return self._ffi.string(data).decode('ascii') + pass class X509Certificate(object): """X509 certificate class.""" - def __init__(self): - self._lib = backend._lib - self._ffi = backend._ffi - certObj = self._lib.X509_new() - if certObj == self._ffi.NULL: - raise X509CertificateError("Could not create X509 certificate " - "object") # pragma: no cover + def __init__(self, certificate=None): + if certificate is None: + self._cert = rfc2459.Certificate() + self._cert['tbsCertificate'] = rfc2459.TBSCertificate() + else: + self._cert = certificate - self._certObj = certObj + @staticmethod + def from_open_file(f): + try: + der_content = pem.readPemFromFile(f) + certificate = decoder.decode(der_content, + asn1Spec=rfc2459.Certificate())[0] + return X509Certificate(certificate) + except Exception: + raise X509CertificateError("Could not read X509 certificate from " + "PEM data.") - def __del__(self): - if getattr(self, '_certObj', None): - self._lib.X509_free(self._certObj) - - def from_buffer(self, data): + @staticmethod + def from_buffer(data): """Build this X509 object from a data buffer in memory. :param data: A data buffer """ - if type(data) != bytes: - data = data.encode('ascii') - bio = backend._bytes_to_bio(data) + return X509Certificate.from_open_file(io.StringIO(data)) - # NOTE(tkelsey): some versions of OpenSSL dont re-use the cert object - # properly, so free it and use the new one - # - certObj = self._lib.PEM_read_bio_X509(bio[0], - self._ffi.NULL, - self._ffi.NULL, - self._ffi.NULL) - if certObj == self._ffi.NULL: - raise X509CertificateError("Could not read X509 certificate from " - "PEM data.") - - self._lib.X509_free(self._certObj) - self._certObj = certObj - - def from_file(self, path): + @staticmethod + def from_file(path): """Build this X509 certificate object from a data file on disk. :param path: A data buffer """ - data = None - with open(path, 'rb') as f: - data = f.read() - self.from_buffer(data) + with open(path, 'r') as f: + return X509Certificate.from_open_file(f) def as_pem(self): """Serialise this X509 certificate object as PEM string.""" - raw_bio = self._lib.BIO_new(self._lib.BIO_s_mem()) - bio = self._ffi.gc(raw_bio, self._lib.BIO_free) - ret = self._lib.PEM_write_bio_X509(bio, self._certObj) - - if ret == 0: - raise X509CertificateError("Could not write X509 certificate " - "as PEM data.") # pragma: no cover - - buf = self._ffi.new("char**") - pem_len = self._lib.BIO_get_mem_data(bio, buf) - pem = self._ffi.string(buf[0], pem_len) - - return pem + header = '-----BEGIN CERTIFICATE-----' + footer = '-----END CERTIFICATE-----' + der_cert = encoder.encode(self._cert) + b64_encoder = (base64.encodestring if str is bytes else + base64.encodebytes) + b64_cert = b64_encoder(der_cert).decode('ascii') + return "%s\n%s%s\n" % (header, b64_cert, footer) def set_version(self, v): """Set the version of this X509 certificate object. :param v: The version """ - ret = self._lib.X509_set_version(self._certObj, v) - if ret == 0: - raise X509CertificateError("Could not set X509 certificate " - "version.") # pragma: no cover + self._cert['tbsCertificate']['version'] = v def get_version(self): """Get the version of this X509 certificate object.""" - return self._lib.X509_get_version(self._certObj) + return self._cert['tbsCertificate']['version'] + + def get_validity(self): + if self._cert['tbsCertificate']['validity'] is None: + self._cert['tbsCertificate']['validity'] = None + return self._cert['tbsCertificate']['validity'] def set_not_before(self, t): """Set the 'not before' date field. @@ -141,15 +122,13 @@ class X509Certificate(object): :param t: time in seconds since the epoch """ asn1_time = utils.timestamp_to_asn1_time(t) - ret = self._lib.X509_set_notBefore(self._certObj, asn1_time) - self._lib.ASN1_TIME_free(asn1_time) - if ret == 0: - raise X509CertificateError("Could not set X509 certificate " - "not before time.") # pragma: no cover + validity = self.get_validity() + validity['notBefore'] = asn1_time def get_not_before(self): """Get the 'not before' date field as seconds since the epoch.""" - not_before = self._lib.X509_get_notBefore(self._certObj) + validity = self.get_validity() + not_before = validity['notBefore'] return utils.asn1_time_to_timestamp(not_before) def set_not_after(self, t): @@ -158,37 +137,28 @@ class X509Certificate(object): :param t: time in seconds since the epoch """ asn1_time = utils.timestamp_to_asn1_time(t) - ret = self._lib.X509_set_notAfter(self._certObj, asn1_time) - self._lib.ASN1_TIME_free(asn1_time) - if ret == 0: - raise X509CertificateError("Could not set X509 certificate " - "not after time.") # pragma: no cover + validity = self.get_validity() + validity['notAfter'] = asn1_time def get_not_after(self): """Get the 'not after' date field as seconds since the epoch.""" - not_after = self._lib.X509_get_notAfter(self._certObj) + validity = self.get_validity() + not_after = validity['notAfter'] return utils.asn1_time_to_timestamp(not_after) def set_pubkey(self, pkey): """Set the public key field. - :param pkey: The public key, an EVP_PKEY ssl type + :param pkey: The public key, rfc2459.SubjectPublicKeyInfo description """ - ret = self._lib.X509_set_pubkey(self._certObj, pkey) - if ret == 0: - raise X509CertificateError("Could not set X509 certificate " - "pubkey.") # pragma: no cover + self._cert['tbsCertificate']['subjectPublicKeyInfo'] = pkey def get_subject(self): """Get the subject name field value. :return: An X509Name object instance """ - val = self._lib.X509_get_subject_name(self._certObj) - if val == self._ffi.NULL: - raise X509CertificateError("Could not get subject from X509 " - "certificate.") # pragma: no cover - + val = self._cert['tbsCertificate']['subject'][0] return name.X509Name(val) def set_subject(self, subject): @@ -197,10 +167,9 @@ class X509Certificate(object): :param subject: An X509Name object instance """ val = subject._name_obj - ret = self._lib.X509_set_subject_name(self._certObj, val) - if ret == 0: - raise X509CertificateError("Could not set X509 certificate " - "subject.") # pragma: no cover + if self._cert['tbsCertificate']['subject'] is None: + self._cert['tbsCertificate']['subject'] = rfc2459.Name() + self._cert['tbsCertificate']['subject'][0] = val def set_issuer(self, issuer): """Set the issuer name field value. @@ -208,20 +177,16 @@ class X509Certificate(object): :param issuer: An X509Name object instance """ val = issuer._name_obj - ret = self._lib.X509_set_issuer_name(self._certObj, val) - if ret == 0: - raise X509CertificateError("Could not set X509 certificate " - "issuer.") # pragma: no cover + if self._cert['tbsCertificate']['issuer'] is None: + self._cert['tbsCertificate']['issuer'] = rfc2459.Name() + self._cert['tbsCertificate']['issuer'][0] = val def get_issuer(self): """Get the issuer name field value. :return: An X509Name object instance """ - val = self._lib.X509_get_issuer_name(self._certObj) - if val == self._ffi.NULL: - raise X509CertificateError("Could not get subject from X509 " - "certificate.") # pragma: no cover + val = self._cert['tbsCertificate']['issuer'][0] return name.X509Name(val) def set_serial_number(self, serial): @@ -232,14 +197,18 @@ class X509Certificate(object): :param serial: The serial number, 32 bit integer """ - asn1_int = self._lib.ASN1_INTEGER_new() - ret = self._lib.ASN1_INTEGER_set(asn1_int, serial) - if ret != 0: - ret = self._lib.X509_set_serialNumber(self._certObj, asn1_int) - self._lib.ASN1_INTEGER_free(asn1_int) - if ret == 0: - raise X509CertificateError("Could not set X509 certificate " - "serial number.") # pragma: no cover + self._cert['tbsCertificate']['serialNumber'] = serial + + def _get_extensions(self): + if self._cert['tbsCertificate']['extensions'] is None: + # this actually initialises the extensions tag rather than + # assign None + self._cert['tbsCertificate']['extensions'] = None + return self._cert['tbsCertificate']['extensions'] + + def get_extensions(self): + extensions = self._get_extensions() + return [extension.construct_extension(e) for e in extensions] def add_extension(self, ext, index): """Add an X509 V3 Certificate extension. @@ -247,10 +216,11 @@ class X509Certificate(object): :param ext: An X509Extension instance :param index: The index of the extension """ - ret = self._lib.X509_add_ext(self._certObj, ext._ext, index) - if ret == 0: - raise X509CertificateError("Could not add X509 certificate " - "extension.") # pragma: no cover + if not isinstance(ext, extension.X509Extension): + raise errors.X509Error("ext needs to be a pyasn1 extension") + + extensions = self._get_extensions() + extensions[index] = ext.as_asn1() def sign(self, key, md='sha1'): """Sign the X509 certificate with a key using a message digest algorithm @@ -262,28 +232,44 @@ class X509Certificate(object): - sha1 - sha256 """ - mda = getattr(self._lib, "EVP_%s" % md, None) - if mda is None: - msg = 'X509 signing error: Unknown algorithm {a}'.format(a=md) - raise X509CertificateError(msg) - ret = self._lib.X509_sign(self._certObj, key, mda()) - if ret == 0: - raise X509CertificateError("X509 signing error: Could not sign " - " certificate.") # pragma: no cover + md = md.upper() + + if isinstance(key, rsa.RSAPrivateKey): + encryption = 'RSA' + elif isinstance(key, dsa.DSAPrivateKey): + encryption = 'DSA' + else: + raise errors.X509Error("Unknown key type: %s" % (key.__class__,)) + + hash_class = utils.get_hash_class(md) + signature_type = SIGNING_ALGORITHMS.get((encryption, md)) + if signature_type is None: + raise errors.X509Error( + "Unknown encryption/hash combination %s/%s" % (encryption, md)) + + algo_id = rfc2459.AlgorithmIdentifier() + algo_id['algorithm'] = signature_type + if encryption == 'RSA': + algo_id['parameters'] = encoder.encode(asn1_univ.Null()) + elif encryption == 'DSA': + pass # parameters should be omitted, see RFC3279 + + self._cert['tbsCertificate']['signature'] = algo_id + + to_sign = encoder.encode(self._cert['tbsCertificate']) + if encryption == 'RSA': + signer = key.signer(padding.PKCS1v15(), hash_class()) + elif encryption == 'DSA': + signer = key.signer(hash_class()) + signer.update(to_sign) + signature = signer.finalize() + self._cert['signatureValue'] = "'%s'B" % ( + utils.bytes_to_bin(signature),) + self._cert['signatureAlgorithm'] = algo_id def as_der(self): """Return this X509 certificate as DER encoded data.""" - buf = None - num = self._lib.i2d_X509(self._certObj, self._ffi.NULL) - if num != 0: - buf = self._ffi.new("unsigned char[]", num + 1) - buf_ptr = self._ffi.new("unsigned char**") - buf_ptr[0] = buf - num = self._lib.i2d_X509(self._certObj, buf_ptr) - else: - raise X509CertificateError("Could not encode X509 certificate " - "as DER.") # pragma: no cover - return buf + return encoder.encode(self._cert) def get_fingerprint(self, md='md5'): """Get the fingerprint of this X509 certificate. @@ -291,7 +277,11 @@ class X509Certificate(object): :param md: The message digest algorthim used to compute the fingerprint :return: The fingerprint encoded as a hex string """ - der = self.as_der() - md = message_digest.MessageDigest(md) - md.update(der) - return md.final() + hash_class = utils.get_hash_class(md) + if hash_class is None: + raise errors.X509Error( + "Unknown hash %s" % (md,)) + hasher = hashes.Hash(hash_class(), + backend=cio_backends.default_backend()) + hasher.update(self.as_der()) + return binascii.hexlify(hasher.finalize()).upper().decode('ascii') diff --git a/anchor/X509/extension.py b/anchor/X509/extension.py new file mode 100644 index 0000000..66a9859 --- /dev/null +++ b/anchor/X509/extension.py @@ -0,0 +1,318 @@ +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import absolute_import + +import functools + +import netaddr +from pyasn1.codec.der import decoder +from pyasn1.codec.der import encoder +from pyasn1.type import constraint as asn1_constraint +from pyasn1.type import namedtype as asn1_namedtype +from pyasn1.type import univ as asn1_univ +from pyasn1_modules import rfc2459 # X509v3 + +from anchor.X509 import errors +from anchor.X509 import utils + + +EXTENSION_NAMES = { + rfc2459.id_ce_policyConstraints: 'policyConstraints', + rfc2459.id_ce_basicConstraints: 'basicConstraints', + rfc2459.id_ce_subjectDirectoryAttributes: 'subjectDirectoryAttributes', + rfc2459.id_ce_deltaCRLIndicator: 'deltaCRLIndicator', + rfc2459.id_ce_cRLDistributionPoints: 'cRLDistributionPoints', + rfc2459.id_ce_issuingDistributionPoint: 'issuingDistributionPoint', + rfc2459.id_ce_nameConstraints: 'nameConstraints', + rfc2459.id_ce_certificatePolicies: 'certificatePolicies', + rfc2459.id_ce_policyMappings: 'policyMappings', + rfc2459.id_ce_privateKeyUsagePeriod: 'privateKeyUsagePeriod', + rfc2459.id_ce_keyUsage: 'keyUsage', + rfc2459.id_ce_authorityKeyIdentifier: 'authorityKeyIdentifier', + rfc2459.id_ce_subjectKeyIdentifier: 'subjectKeyIdentifier', + rfc2459.id_ce_certificateIssuer: 'certificateIssuer', + rfc2459.id_ce_subjectAltName: 'subjectAltName', + rfc2459.id_ce_issuerAltName: 'issuerAltName', +} + + +LONG_KEY_USAGE_NAMES = { + "Digital Signature": "digitalSignature", + "Non Repudiation": "nonRepudiation", + "Key Encipherment": "keyEncipherment", + "Data Encipherment": "dataEncipherment", + "Key Agreement": "keyAgreement", + "Certificate Sign": "keyCertSign", + "CRL Sign": "cRLSign", + "Encipher Only": "encipherOnly", + "Decipher Only": "decipherOnly", +} + + +def uses_ext_value(f): + """Wrapper allowing reading of extension value. + + Because the value is normally saved in a (double) serialised way, it's + not easily accessible to the member methods. This is made easier by + unpacking the extension value into an extra argument. + """ + @functools.wraps(f) + def ext_value_filled(self, *args, **kwargs): + kwargs['ext_value'] = self._get_value() + return f(self, *args, **kwargs) + return ext_value_filled + + +def modifies_ext_value(f): + """Wrapper allowing modification of extension value. + + Because the value is normally saved in a (double) serialised way, it's + not easily accessible to the member methods. This is made easier by + unpacking the extension value into an extra argument. + New value needs to be returned from the method. + """ + @functools.wraps(f) + def ext_value_filled(self, *args, **kwargs): + value = self._get_value() + kwargs['ext_value'] = value + # since some elements like NamedValue are pure value types, there is + # no interface to modify them and new versions have to be returned + value = f(self, *args, **kwargs) + self._set_value(value) + return ext_value_filled + + +class BasicConstraints(asn1_univ.Sequence): + """Custom BasicConstraint implementation until pyasn1_modules is fixes.""" + componentType = asn1_namedtype.NamedTypes( + asn1_namedtype.DefaultedNamedType('cA', asn1_univ.Boolean(False)), + asn1_namedtype.OptionalNamedType( + 'pathLenConstraint', + asn1_univ.Integer().subtype( + subtypeSpec=asn1_constraint.ValueRangeConstraint(0, 64))) + ) + + +class X509Extension(object): + """Abstraction for the pyasn1 Extension structures. + + The object should normally be constructed using `construct_extension`, + which will choose the right extension type based on the id. + Each extension has an immutable oid and a spec of the internal value + representation. + Unknown extension types can be still represented by the + X509Extension object and copied/serialised without understanding the + value details. The value will not be displayed properly in the logs + in the case. + """ + _oid = None + spec = None + + """An X509 V3 Certificate extension.""" + def __init__(self, ext=None): + if ext is None: + if self.spec is None: + raise errors.X509Error("cannot create generic extension") + self._ext = rfc2459.Extension() + self._ext['extnID'] = self._oid + self._set_value(self._get_default_value()) + else: + if not isinstance(ext, rfc2459.Extension): + raise errors.X509Error("extension has incorrect type") + self._ext = ext + + @classmethod + def _get_default_value(cls): + # if there are any non-optional fields, this needs to be defined in + # the class + return cls.spec() + + def __str__(self): + return "%s: %s" % (self.get_name(), self.get_value_as_str()) + + def get_value_as_str(self): + return "" + + def get_oid(self): + return self._ext['extnID'] + + def get_name(self): + """Get the extension name as a python string.""" + oid = self.get_oid() + return EXTENSION_NAMES.get(oid, oid) + + def get_critical(self): + return self._ext['critical'] + + def set_critical(self, critical): + self._ext['critical'] = critical + + def _get_value(self): + value_der = decoder.decode(self._ext['extnValue'])[0] + return decoder.decode(value_der, asn1Spec=self.spec())[0] + + def _set_value(self, value): + if not isinstance(value, self.spec): + raise errors.X509Error("extension value has incorrect type") + self._ext['extnValue'] = encoder.encode(rfc2459.univ.OctetString( + encoder.encode(value))) + + def as_der(self): + return encoder.encode(self._ext) + + def as_asn1(self): + return self._ext + + +class X509ExtensionBasicConstraints(X509Extension): + spec = BasicConstraints + _oid = rfc2459.id_ce_basicConstraints + + @uses_ext_value + def get_ca(self, ext_value=None): + return bool(ext_value['cA']) + + @modifies_ext_value + def set_ca(self, ca, ext_value=None): + ext_value['cA'] = ca + return ext_value + + @uses_ext_value + def get_path_len_constraint(self, ext_value=None): + return ext_value['pathLenConstraint'] + + @modifies_ext_value + def set_path_len_constraint(self, length, ext_value=None): + ext_value['pathLenConstraint'] = length + return ext_value + + def __str__(self): + return "basicConstraints: CA: %s, pathLen: %s" % ( + str(self.get_ca()).upper(), self.get_path_len_constraint()) + + +class X509ExtensionKeyUsage(X509Extension): + spec = rfc2459.KeyUsage + _oid = rfc2459.id_ce_keyUsage + + fields = dict(spec.namedValues.namedValues) + inv_fields = dict((v, k) for k, v in spec.namedValues.namedValues) + + @classmethod + def _get_default_value(cls): + # if there are any non-optional fields, this needs to be defined in + # the class + return cls.spec("''B") + + @uses_ext_value + def get_usage(self, usage, ext_value=None): + usage = LONG_KEY_USAGE_NAMES.get(usage, usage) + pos = self.fields[usage] + if pos >= len(ext_value): + return False + return bool(ext_value[pos]) + + @uses_ext_value + def get_all_usages(self, ext_value=None): + return [self.inv_fields[i] for i, enabled in enumerate(ext_value) + if enabled] + + @modifies_ext_value + def set_usage(self, usage, state, ext_value=None): + usage = LONG_KEY_USAGE_NAMES.get(usage, usage) + pos = self.fields[usage] + values = [x for x in ext_value] + + if state: + while pos >= len(values): + values.append(0) + values[pos] = 1 + else: + if pos < len(values): + values[pos] = 0 + + bits = ''.join(str(x) for x in values) + return self.spec("'%s'B" % bits) + + def __str__(self): + return "keyUsage: " + ", ".join(self.get_all_usages()) + + +class X509ExtensionSubjectAltName(X509Extension): + spec = rfc2459.SubjectAltName + _oid = rfc2459.id_ce_subjectAltName + + @uses_ext_value + def get_dns_ids(self, ext_value=None): + dns_ids = [] + for name in ext_value: + if name.getName() != 'dNSName': + continue + component = name.getComponent() + dns_id = component.asOctets().decode(component.encoding) + dns_ids.append(dns_id) + return dns_ids + + @uses_ext_value + def get_ips(self, ext_value=None): + ips = [] + for name in ext_value: + if name.getName() != 'iPAddress': + continue + ips.append(utils.asn1_to_netaddr(name.getComponent())) + return ips + + @modifies_ext_value + def add_dns_id(self, dns_id, ext_value=None): + # TODO(stan) validate dns_id + new_pos = len(ext_value) + ext_value[new_pos] = None + ext_value[new_pos]['dNSName'] = dns_id + return ext_value + + @modifies_ext_value + def add_ip(self, ip, ext_value=None): + if not isinstance(ip, netaddr.IPAddress): + raise errors.X509Error("not a real ip address provided") + new_pos = len(ext_value) + ext_value[new_pos] = None + ext_value[new_pos]['iPAddress'] = utils.netaddr_to_asn1(ip) + return ext_value + + @uses_ext_value + def __str__(self, ext_value=None): + entries = ["DNS:%s" % (x,) for x in self.get_dns_ids()] + entries += ["IP:%s" % (x,) for x in self.get_ips()] + return "subjectAltName: " + ", ".join(entries) + + +EXTENSION_CLASSES = { + rfc2459.id_ce_basicConstraints: X509ExtensionBasicConstraints, + rfc2459.id_ce_keyUsage: X509ExtensionKeyUsage, + rfc2459.id_ce_subjectAltName: X509ExtensionSubjectAltName, +} + + +def construct_extension(ext): + """Construct an extension object of the right type. + + While X509Extension can provide basic access to the extension elements, + it cannot parse details of extensions. This function detects which type + should be used based on the extension id. + If the type is unknown, generic X509Extension is used instead. + """ + if not isinstance(ext, rfc2459.Extension): + raise errors.X509Error("extension has incorrect type") + ext_class = EXTENSION_CLASSES.get(ext['extnID'], X509Extension) + return ext_class(ext) diff --git a/anchor/X509/message_digest.py b/anchor/X509/message_digest.py deleted file mode 100644 index 53a004c..0000000 --- a/anchor/X509/message_digest.py +++ /dev/null @@ -1,93 +0,0 @@ -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - -from __future__ import absolute_import - -from cryptography.hazmat.backends.openssl import backend - -import binascii - - -class MessageDigestError(Exception): - def __init__(self, what): - super(MessageDigestError, self).__init__(what) - - -class MessageDigest(object): - """Compute a message digest from input data.""" - - @staticmethod - def getValidAlgorithms(): - """Get a list of available valid hash algorithms.""" - algs = [ - "md5", - "ripemd160", - "sha224", - "sha256", - "sha384", - "sha512" - ] - ret = [] - for alg in algs: - if getattr(backend._lib, "EVP_%s" % alg, None) is not None: - ret.append(alg) - return ret - - def __init__(self, algo): - self._lib = backend._lib - self._ffi = backend._ffi - md = getattr(self._lib, "EVP_%s" % algo, None) - if md is None: - msg = 'MessageDigest error: unknown algorithm {a}'.format(a=algo) - raise MessageDigestError(msg) - - ret = 0 - ctx = self._lib.EVP_MD_CTX_create() - if ctx != self._ffi.NULL: - self.ctx = ctx - self.mda = md() - ret = self._lib.EVP_DigestInit_ex(self.ctx, - self.mda, - self._ffi.NULL) - - if ret == 0: - raise MessageDigestError( - "Could not setup message digest context.") # pragma: no cover - - def __del__(self): - if getattr(self, 'ctx', None): - self._lib.EVP_MD_CTX_cleanup(self.ctx) - self._lib.EVP_MD_CTX_destroy(self.ctx) - - def update(self, data): - """Add more data to the digest.""" - - ret = self._lib.EVP_DigestUpdate(self.ctx, data, len(data)) - if ret == 0: - raise MessageDigestError( - "Failed to update message digest data.") # pragma: no cover - - def final(self): - """get the final resulting digest value. - - Note that you should not call update() with additional data after using - final. - """ - sz = self._lib.EVP_MD_size(self.mda) - data = self._ffi.new("char[]", sz) - ret = self._lib.EVP_DigestFinal_ex(self.ctx, data, self._ffi.NULL) - if ret == 0: - raise MessageDigestError( - "Failed to get message digest.") # pragma: no cover - digest = self._ffi.string(data) - return binascii.hexlify(digest).decode('ascii').upper() diff --git a/anchor/X509/name.py b/anchor/X509/name.py index 2982989..4d6a123 100644 --- a/anchor/X509/name.py +++ b/anchor/X509/name.py @@ -13,21 +13,64 @@ from __future__ import absolute_import -from cryptography.hazmat.backends.openssl import backend +from pyasn1.codec.der import decoder +from pyasn1.codec.der import encoder +from pyasn1.type import error as asn1_error +from pyasn1.type import univ as asn1_univ +from pyasn1_modules import rfc2459 from anchor.X509 import errors -from anchor.X509 import utils +OID_commonName = rfc2459.id_at_commonName +OID_localityName = rfc2459.id_at_localityName +OID_stateOrProvinceName = rfc2459.id_at_stateOrProvinceName +OID_organizationName = rfc2459.id_at_organizationName +OID_organizationalUnitName = rfc2459.id_at_organizationalUnitName +OID_countryName = rfc2459.id_at_countryName +OID_pkcs9_emailAddress = rfc2459.emailAddress +OID_surname = rfc2459.id_at_sutname +OID_givenName = rfc2459.id_at_givenName -NID_countryName = backend._lib.NID_countryName -NID_stateOrProvinceName = backend._lib.NID_stateOrProvinceName -NID_localityName = backend._lib.NID_localityName -NID_organizationName = backend._lib.NID_organizationName -NID_organizationalUnitName = backend._lib.NID_organizationalUnitName -NID_commonName = backend._lib.NID_commonName -NID_pkcs9_emailAddress = backend._lib.NID_pkcs9_emailAddress -NID_surname = backend._lib.NID_surname -NID_givenName = backend._lib.NID_givenName +name_oids = { + rfc2459.id_at_name: rfc2459.X520name, + rfc2459.id_at_sutname: rfc2459.X520name, + rfc2459.id_at_givenName: rfc2459.X520name, + rfc2459.id_at_initials: rfc2459.X520name, + rfc2459.id_at_generationQualifier: rfc2459.X520name, + rfc2459.id_at_commonName: rfc2459.X520CommonName, + rfc2459.id_at_localityName: rfc2459.X520LocalityName, + rfc2459.id_at_stateOrProvinceName: rfc2459.X520StateOrProvinceName, + rfc2459.id_at_organizationName: rfc2459.X520OrganizationName, + rfc2459.id_at_organizationalUnitName: rfc2459.X520OrganizationalUnitName, + rfc2459.id_at_title: rfc2459.X520Title, + rfc2459.id_at_dnQualifier: rfc2459.X520dnQualifier, + rfc2459.id_at_countryName: rfc2459.X520countryName, + rfc2459.emailAddress: rfc2459.Pkcs9email, +} + +code_names = { + rfc2459.id_at_commonName: "CN", + rfc2459.id_at_localityName: "L", + rfc2459.id_at_stateOrProvinceName: "ST", + rfc2459.id_at_organizationName: "O", + rfc2459.id_at_organizationalUnitName: "OU", + rfc2459.id_at_countryName: "C", + rfc2459.id_at_givenName: "GN", + rfc2459.id_at_sutname: "SN", + rfc2459.emailAddress: "emailAddress", +} + +short_names = { + rfc2459.id_at_commonName: "commonName", + rfc2459.id_at_localityName: "localityName", + rfc2459.id_at_stateOrProvinceName: "stateOrProvinceName", + rfc2459.id_at_organizationName: "organizationName", + rfc2459.id_at_organizationalUnitName: "organizationalUnitName", + rfc2459.id_at_countryName: "countryName", + rfc2459.id_at_givenName: "givenName", + rfc2459.id_at_sutname: "surname", + rfc2459.emailAddress: "emailAddress", +} class X509Name(object): @@ -35,104 +78,90 @@ class X509Name(object): class Entry(): """An X509 Name sub-entry object.""" - def __init__(self, obj, parent): - self._parent = parent - self._lib = backend._lib - self._ffi = backend._ffi - self._entry = obj + def __init__(self, obj): + self._obj = obj def __str__(self): return "%s: %s" % (self.get_name(), self.get_value()) + def get_oid(self): + return self._obj[0]['type'] + def get_name(self): """Get the name of this entry. :return: entry name as a python string """ - asn1_obj = self._lib.X509_NAME_ENTRY_get_object(self._entry) - buf = self._ffi.new('char[]', 1024) - ret = self._lib.OBJ_obj2txt(buf, 1024, asn1_obj, 0) - if ret == 0: - raise errors.X509Error("Could not convert ASN1_OBJECT to " - "string.") # pragma: no cover - return self._ffi.string(buf).decode('ascii') + oid = self.get_oid() + return short_names.get(oid, str(oid)) + + def get_code(self): + """Get the name of this entry. + + :return: entry name as a python string + """ + oid = self.get_oid() + return code_names.get(oid, str(oid)) def get_value(self): """Get the value of this entry. :return: entry value as a python string """ - val = self._lib.X509_NAME_ENTRY_get_data(self._entry) - return utils.asn1_string_to_utf8(val) + value = self._obj[0]['value'] + der = value.asOctets() + name_spec = name_oids[self.get_oid()]() + + value = decoder.decode(der, asn1Spec=name_spec)[0] + if hasattr(value, 'getComponent'): + value = value.getComponent() + return value.asOctets().decode(value.encoding) def __init__(self, name_obj=None): - self._lib = backend._lib - self._ffi = backend._ffi if name_obj is not None: - self._name_obj = self._lib.X509_NAME_dup(name_obj) - if self._name_obj == self._ffi.NULL: - raise errors.X509Error("Failed to copy X509_NAME " - "object.") # pragma: no cover + if not isinstance(name_obj, rfc2459.RDNSequence): + raise TypeError("name is not an RDNSequence") + # TODO(stan): actual copy + self._name_obj = name_obj else: - self._name_obj = self._lib.X509_NAME_new() - if self._name_obj == self._ffi.NULL: - raise errors.X509Error("Failed to create " - "X509_NAME object.") # pragma: no cover - - def __del__(self): - self._lib.X509_NAME_free(self._name_obj) + self._name_obj = rfc2459.RDNSequence() def __str__(self): - # NOTE(tkelsey): we need to pass in a max size, so why not 1024 - val = self._lib.X509_NAME_oneline(self._name_obj, self._ffi.NULL, 1024) - if val == self._ffi.NULL: - raise errors.X509Error("Could not convert" - " X509_NAME to string.") # pragma: no cover - - val = self._ffi.gc(val, self._lib.OPENSSL_free) - return self._ffi.string(val).decode('ascii') + return '/' + '/'.join("%s=%s" % (e.get_code(), e.get_value()) + for e in self) def __len__(self): - return self._lib.X509_NAME_entry_count(self._name_obj) + return len(self._name_obj) def __getitem__(self, idx): - if not (0 <= idx < self.entry_count()): - raise IndexError("index out of range") - ent = self._lib.X509_NAME_get_entry(self._name_obj, idx) - return X509Name.Entry(ent, self) + return X509Name.Entry(self._name_obj[idx]) def __iter__(self): - for i in range(self.entry_count()): + for i in range(len(self)): yield self[i] - def add_name_entry(self, nid, text): - """Add a name entry by its NID name.""" - ret = self._lib.X509_NAME_add_entry_by_NID( - self._name_obj, nid, - self._lib.MBSTRING_UTF8, - text.encode('utf8'), -1, -1, 0) + def add_name_entry(self, oid, text): + if not isinstance(oid, asn1_univ.ObjectIdentifier): + raise errors.X509Error("oid '%s' is not valid" % (oid,)) + entry = rfc2459.RelativeDistinguishedName() + entry[0] = rfc2459.AttributeTypeAndValue() + entry[0]['type'] = oid + name_type = name_oids[oid] + try: + if name_type in (rfc2459.X520countryName, rfc2459.Pkcs9email): + val = name_type(text) + else: + val = name_type() + val['utf8String'] = text + except asn1_error.ValueConstraintError: + raise errors.X509Error("Name '%s' is not valid" % text) + entry[0]['value'] = rfc2459.AttributeValue(encoder.encode(val)) + self._name_obj[len(self)] = entry - if ret != 1: - raise errors.X509Error("Failed to add name entry: '%s' '%s'" % ( - nid, text)) - - def entry_count(self): - """Get the number of entries in the name object.""" - return self._lib.X509_NAME_entry_count(self._name_obj) - - def get_entries_by_nid(self, nid): + def get_entries_by_oid(self, oid): """Get a name entry corresponding to an NID name. :param nid: an NID for the new name entry :return: An X509Name.Entry object """ - out = [] - idx = self._lib.X509_NAME_get_index_by_NID(self._name_obj, nid, -1) - while idx != -1: - val = self._lib.X509_NAME_get_entry(self._name_obj, idx) - if val != self._ffi.NULL: - out.append(X509Name.Entry(val, self)) - - idx = self._lib.X509_NAME_get_index_by_NID(self._name_obj, - nid, idx) - return out + return [entry for entry in self if entry.get_oid() == oid] diff --git a/anchor/X509/signing_request.py b/anchor/X509/signing_request.py index b031b9b..349b10b 100644 --- a/anchor/X509/signing_request.py +++ b/anchor/X509/signing_request.py @@ -13,13 +13,23 @@ from __future__ import absolute_import -from cryptography.hazmat.backends.openssl import backend +import io + +from pyasn1.codec.der import decoder +from pyasn1.codec.der import encoder +from pyasn1.type import univ as asn1_univ +from pyasn1_modules import pem +from pyasn1_modules import rfc2314 # PKCS#10 / CSR +from pyasn1_modules import rfc2459 # X509 -from anchor.X509 import certificate from anchor.X509 import errors +from anchor.X509 import extension from anchor.X509 import name +OID_extensionRequest = asn1_univ.ObjectIdentifier('1.2.840.113549.1.9.14') + + class X509CsrError(errors.X509Error): def __init__(self, what): super(X509CsrError, self).__init__(what) @@ -27,84 +37,108 @@ class X509CsrError(errors.X509Error): class X509Csr(object): """An X509 Certificate Signing Request.""" - def __init__(self): - self._lib = backend._lib - self._ffi = backend._ffi - csrObj = self._lib.X509_REQ_new() - if csrObj == self._ffi.NULL: - raise X509CsrError( - "Could not create X509 CSR Object.") # pragma: no cover + def __init__(self, csr=None): + if csr is None: + self._csr = rfc2314.CertificationRequest() + else: + self._csr = csr - self._csrObj = csrObj + @staticmethod + def from_open_file(f): + try: + der_content = pem.readPemFromFile( + f, startMarker='-----BEGIN CERTIFICATE REQUEST-----', + endMarker='-----END CERTIFICATE REQUEST-----') + csr = decoder.decode(der_content, + asn1Spec=rfc2314.CertificationRequest())[0] + return X509Csr(csr) + except Exception: + raise X509CsrError("Could not read X509 certificate from " + "PEM data.") - def __del__(self): - if getattr(self, '_csrObj', None): - self._lib.X509_REQ_free(self._csrObj) - - def from_buffer(self, data, password=None): + @staticmethod + def from_buffer(data): """Create this CSR from a buffer :param data: The data buffer - :param password: decryption password, if needed """ - if type(data) != bytes: - data = data.encode('ascii') - bio = backend._bytes_to_bio(data) - ptr = self._ffi.new("X509_REQ **") - ptr[0] = self._csrObj - ret = self._lib.PEM_read_bio_X509_REQ(bio[0], ptr, - self._ffi.NULL, - self._ffi.NULL) - if ret == self._ffi.NULL: - raise X509CsrError("Could not read X509 CSR from PEM data.") + return X509Csr.from_open_file(io.StringIO(data)) - def from_file(self, path, password=None): + @staticmethod + def from_file(path): """Create this CSR from a file on disk :param path: Path to the file on disk - :param password: decryption password, if needed """ - data = None - with open(path, 'rb') as f: - data = f.read() - self.from_buffer(data, password) + with open(path, 'r') as f: + return X509Csr.from_open_file(f) def get_pubkey(self): """Get the public key from the CSR :return: an OpenSSL EVP_PKEY object """ - pkey = self._lib.X509_REQ_get_pubkey(self._csrObj) - if pkey == self._ffi.NULL: - raise X509CsrError( - "Could not get pubkey from X509 CSR.") # pragma: no cover + return self._csr['certificationRequestInfo']['subjectPublicKeyInfo'] - return pkey + def get_request_info(self): + if self._csr['certificationRequestInfo'] is None: + self._csr['certificationRequestInfo'] = None + return self._csr['certificationRequestInfo'] def get_subject(self): """Get the subject name field from the CSR :return: an X509Name object """ - subs = self._lib.X509_REQ_get_subject_name(self._csrObj) - if subs == self._ffi.NULL: - raise X509CsrError( - "Could not get subject from X509 CSR.") # pragma: no cover + ri = self.get_request_info() + if ri['subject'] is None: + ri['subject'] = None + # setup first RDN sequence + ri['subject'][0] = None - return name.X509Name(subs) + subject = ri['subject'][0] + return name.X509Name(subject) - def get_extensions(self): + def get_attributes(self): + ri = self.get_request_info() + if ri['attributes'] is None: + ri['attributes'] = None + return ri['attributes'] + + def get_extensions(self, ext_type=None): """Get the list of all X509 V3 Extensions on this CSR :return: a list of X509Extension objects """ - # TODO(tkelsey): I assume the ext list copies data and this is safe - # TODO(tkelsey): Error checking needed here - ret = [] - exts = self._lib.X509_REQ_get_extensions(self._csrObj) - num = self._lib.sk_X509_EXTENSION_num(exts) - for i in range(0, num): - ext = self._lib.sk_X509_EXTENSION_value(exts, i) - ret.append(certificate.X509Extension(ext)) - self._lib.sk_X509_EXTENSION_free(exts) - return ret + ext_attrs = [a for a in self.get_attributes() + if a['type'] == OID_extensionRequest] + if len(ext_attrs) == 0: + return [] + else: + exts_der = ext_attrs[0]['vals'][0].asOctets() + exts = decoder.decode(exts_der, asn1Spec=rfc2459.Extensions())[0] + return [extension.construct_extension(e) for e in exts + if ext_type is None or e['extnID'] == ext_type._oid] + + def add_extension(self, ext): + if not isinstance(ext, extension.X509Extension): + raise errors.X509Error("ext is not an anchor X509Extension") + attributes = self.get_attributes() + ext_attrs = [a for a in attributes + if a['type'] == OID_extensionRequest] + if not ext_attrs: + new_attr_index = len(attributes) + attributes[new_attr_index] = None + ext_attr = attributes[new_attr_index] + ext_attr['type'] = OID_extensionRequest + ext_attr['vals'] = None + exts = rfc2459.Extensions() + else: + ext_attr = ext_attrs[0] + exts = decoder.decode(ext_attr['vals'][0].asOctets(), + asn1Spec=rfc2459.Extensions())[0] + + new_ext_index = len(exts) + exts[new_ext_index] = ext._ext + + ext_attr['vals'][0] = encoder.encode(exts) diff --git a/anchor/X509/utils.py b/anchor/X509/utils.py index 3eee316..5f227ef 100644 --- a/anchor/X509/utils.py +++ b/anchor/X509/utils.py @@ -15,38 +15,18 @@ from __future__ import absolute_import import calendar import datetime +import struct -from cryptography.hazmat.backends.openssl import backend +from cryptography.hazmat import backends +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives import serialization +import netaddr +from pyasn1.type import useful as asn1_useful +from pyasn1_modules import rfc2459 from anchor.X509 import errors -def load_pem_private_key(key_data, passwd=None): - """Load and return an OpenSSL EVP_PKEY public key object from a data buffer - - :param key_data: The data buffer - :param passwd: Decryption password if neded (not used for now) - :return: an OpenSSL EVP_PKEY public key object - """ - # TODO(tkelsey): look at using backend.read_private_key - # - - if type(key_data) != bytes: - key_data = key_data.encode('ascii') - lib = backend._lib - ffi = backend._ffi - data = backend._bytes_to_bio(key_data) - - evp_pkey = lib.EVP_PKEY_new() - evp_pkey_ptr = ffi.new("EVP_PKEY**") - evp_pkey_ptr[0] = evp_pkey - evp_pkey = lib.PEM_read_bio_PrivateKey(data[0], evp_pkey_ptr, - ffi.NULL, ffi.NULL) - - evp_pkey = ffi.gc(evp_pkey, lib.EVP_PKEY_free) - return evp_pkey - - def create_timezone(minute_offset): """Create a new timezone with a specified offset. @@ -80,18 +60,17 @@ def asn1_time_to_timestamp(t): :param t: ASN1_TIME to convert """ - - gen_time = backend._lib.ASN1_TIME_to_generalizedtime(t, backend._ffi.NULL) - if gen_time == backend._ffi.NULL: - raise errors.ASN1TimeError("time conversion failure") - - try: - return asn1_generalizedtime_to_timestamp(gen_time) - finally: - backend._lib.ASN1_GENERALIZEDTIME_free(gen_time) + component = t.getComponent() + timestring = component.asOctets().decode(component.encoding) + if isinstance(component, asn1_useful.UTCTime): + if int(timestring[0]) >= 5: + timestring = "19" + timestring + else: + timestring = "20" + timestring + return asn1_timestring_to_timestamp(timestring) -def asn1_generalizedtime_to_timestamp(gt): +def asn1_timestring_to_timestamp(timestring): """Convert from ASN1_GENERALIZEDTIME to UTC-based timestamp. :param gt: ASN1_GENERALIZEDTIME to convert @@ -99,11 +78,8 @@ def asn1_generalizedtime_to_timestamp(gt): # ASN1_GENERALIZEDTIME is actually a string in known formats, # so the conversion can be done in this code - string_time = backend._ffi.cast("ASN1_STRING*", gt) - res = asn1_string_to_utf8(string_time) - - before_tz = res[:14] - tz_str = res[14:] + before_tz = timestring[:14] + tz_str = timestring[14:] d = datetime.datetime.strptime(before_tz, "%Y%m%d%H%M%S") if tz_str == 'Z': # YYYYMMDDhhmmssZ @@ -126,25 +102,85 @@ def timestamp_to_asn1_time(t): """ d = datetime.datetime.utcfromtimestamp(t) - # use the ASN1_GENERALIZEDTIME format - time_str = d.strftime("%Y%m%d%H%M%SZ").encode('ascii') - asn1_time = backend._lib.ASN1_STRING_type_new( - backend._lib.V_ASN1_GENERALIZEDTIME) - backend._lib.ASN1_STRING_set(asn1_time, time_str, len(time_str)) - asn1_gentime = backend._ffi.cast("ASN1_GENERALIZEDTIME*", asn1_time) - if backend._lib.ASN1_GENERALIZEDTIME_check(asn1_gentime) == 0: - raise errors.ASN1TimeError("timestamp not accepted by ASN1 check") - - # ASN1_GENERALIZEDTIME is a form of ASN1_TIME, so a pointer cast is valid - return backend._ffi.cast("ASN1_TIME*", asn1_time) + asn1time = rfc2459.Time() + if d.year <= 2049: + time_str = d.strftime("%y%m%d%H%M%SZ").encode('ascii') + asn1time['utcTime'] = time_str + else: + time_str = d.strftime("%Y%m%d%H%M%SZ").encode('ascii') + asn1time['generalTime'] = time_str + return asn1time -def asn1_string_to_utf8(asn1_string): - buf = backend._ffi.new("unsigned char **") - res = backend._lib.ASN1_STRING_to_UTF8(buf, asn1_string) - if res < 0 or buf[0] == backend._ffi.NULL: - raise errors.ASN1StringError("cannot convert asn1 to python string") - buf = backend._ffi.gc( - buf, lambda buffer: backend._lib.OPENSSL_free(buffer[0]) - ) - return backend._ffi.buffer(buf[0], res)[:].decode('utf8') +# functions needed for converting the pyasn1 signature fields +def bin_to_bytes(bits): + """Convert bit string to byte string.""" + bits = ''.join(str(b) for b in bits) + bits = _pad_byte(bits) + octets = [bits[8*i:8*(i+1)] for i in range(len(bits)/8)] + bytes = [chr(int(x, 2)) for x in octets] + return "".join(bytes) + + +# ord good for py2 and py3 +local_ord = ord if str is bytes else lambda x: x + + +def _pad_byte(bits): + """Pad a string of bits with zeros to make its length a multiple of 8.""" + r = len(bits) % 8 + return ((8-r) % 8)*'0' + bits + + +def bytes_to_bin(bytes): + """Convert byte string to bit string.""" + return "".join([_pad_byte(_int_to_bin(local_ord(byte))) for byte in bytes]) + + +def _int_to_bin(n): + if n == 0 or n == 1: + return str(n) + elif n % 2 == 0: + return _int_to_bin(n // 2) + "0" + else: + return _int_to_bin(n // 2) + "1" + + +def get_hash_class(md): + return getattr(hashes, md.upper(), None) + + +def get_private_key_from_bytes(data): + key = serialization.load_pem_private_key( + data, None, backend=backends.default_backend()) + return key + + +def get_private_key_from_file(path): + with open(path, 'rb') as f: + return get_private_key_from_bytes(f.read()) + + +def asn1_to_netaddr(octet_string): + """Translate the ASN1 IP format to netaddr object.""" + if not isinstance(octet_string, rfc2459.univ.OctetString): + raise TypeError("not an OctetString") + + ip_bytes = octet_string.asOctets() + if len(ip_bytes) == 4: + ip_num = struct.unpack(">I", ip_bytes)[0] + return netaddr.IPAddress(ip_num, 4) + elif len(ip_bytes) == 16: + ip_num_front, ip_num_back = struct.unpack(">QQ", ip_bytes) + ip_num = ip_num_front << 64 | ip_num_back + return netaddr.IPAddress(ip_num, 6) + else: + raise TypeError("ip address is neither v4 nor v6") + + +def netaddr_to_asn1(ip): + """Translate the netaddr object to ASN1 IP format.""" + if not isinstance(ip, netaddr.IPAddress): + raise errors.X509Error("not a real ip address provided") + + return bytes(ip.packed) diff --git a/anchor/certificate_ops.py b/anchor/certificate_ops.py index d1fdbef..d789cda 100644 --- a/anchor/certificate_ops.py +++ b/anchor/certificate_ops.py @@ -25,7 +25,7 @@ from anchor import jsonloader from anchor import validators from anchor.X509 import certificate from anchor.X509 import signing_request -from anchor.X509 import utils as X509_utils +from anchor.X509 import utils logger = logging.getLogger(__name__) @@ -54,8 +54,7 @@ def parse_csr(csr, encoding): # load the CSR into the backend X509 library try: - out_req = signing_request.X509Csr() - out_req.from_buffer(csr) + out_req = signing_request.X509Csr.from_buffer(csr) return out_req except Exception as e: logger.exception("Exception while parsing the CSR: %s", e) @@ -132,17 +131,14 @@ def sign(csr): :param csr: X509 certificate signing request """ try: - ca = certificate.X509Certificate() - ca.from_file(jsonloader.conf.ca["cert_path"]) + ca = certificate.X509Certificate.from_file( + jsonloader.conf.ca["cert_path"]) except Exception as e: logger.exception("Cannot load the signing CA: %s", e) pecan.abort(500, "certificate signing error") try: - key_data = None - with open(jsonloader.conf.ca["key_path"]) as f: - key_data = f.read() - key = X509_utils.load_pem_private_key(key_data) + key = utils.get_private_key_from_file(jsonloader.conf.ca['key_path']) except Exception as e: logger.exception("Cannot load the signing CA key: %s", e) pecan.abort(500, "certificate signing error") @@ -182,7 +178,7 @@ def sign(csr): cert_pem = new_cert.as_pem() - with open(path, "wb") as f: + with open(path, "w") as f: f.write(cert_pem) return cert_pem diff --git a/anchor/validators.py b/anchor/validators.py index ea61950..67abaae 100644 --- a/anchor/validators.py +++ b/anchor/validators.py @@ -17,6 +17,7 @@ import logging import netaddr +from anchor.X509 import extension from anchor.X509 import name as x509_name @@ -29,7 +30,7 @@ class ValidationError(Exception): def csr_get_cn(csr): name = csr.get_subject() - data = name.get_entries_by_nid(x509_name.NID_commonName) + data = name.get_entries_by_oid(x509_name.OID_commonName) if len(data) > 0: return data[0].get_value() else: @@ -51,19 +52,14 @@ def check_domains(domain, allowed_domains): def iter_alternative_names(csr, types, fail_other_types=True): for ext in csr.get_extensions(): - if ext.get_name() == "subjectAltName": - alternatives = [alt.strip() for alt in ext.get_value().split(',')] - for alternative in alternatives: - parts = alternative.split(':', 1) - if len(parts) != 2: - # it has at least one part, so parts[0] is valid - raise ValidationError("Alt name should have 2 parts, but " - "found: '%s'" % parts[0]) - if parts[0] in types: - yield parts - elif fail_other_types: - raise ValidationError("Alt name '%s' has unexpected type " - "'%s'" % (parts[1], parts[0])) + if isinstance(ext, extension.X509ExtensionSubjectAltName): + # TODO(stan): fail on other types + if 'DNS' in types: + for dns_id in ext.get_dns_ids(): + yield ('DNS', dns_id) + if 'IP Address' in types: + for ip in ext.get_ips(): + yield ('IP Address', ip) def check_networks(ip, allowed_networks): @@ -91,15 +87,15 @@ def common_name(csr, allowed_domains=[], allowed_networks=[], **kwargs): alt_present = any(ext.get_name() == "subjectAltName" for ext in csr.get_extensions()) - CNs = csr.get_subject().get_entries_by_nid(x509_name.NID_commonName) + CNs = csr.get_subject().get_entries_by_oid(x509_name.OID_commonName) if len(CNs) > 1: raise ValidationError("Too many CNs in the request") - if not alt_present: - # rfc5280#section-4.2.1.6 says so - if len(CNs) == 0: - raise ValidationError("Alt subjects have to exist if the main" - " subject doesn't") + + # rfc5280#section-4.2.1.6 says so + if len(CNs) == 0 and not alt_present: + raise ValidationError("Alt subjects have to exist if the main" + " subject doesn't") if len(CNs) > 0: cn = csr_get_cn(csr) @@ -122,7 +118,7 @@ def alternative_names(csr, allowed_domains=[], **kwargs): the list of known suffixes, or network ranges. """ - for name_type, name in iter_alternative_names(csr, ['DNS']): + for _, name in iter_alternative_names(csr, ['DNS']): if not check_domains(name, allowed_domains): raise ValidationError("Domain '%s' not allowed (doesn't" " match known domains)" @@ -142,9 +138,8 @@ def alternative_names_ip(csr, allowed_domains=[], allowed_networks=[], raise ValidationError("Domain '%s' not allowed (doesn't" " match known domains)" % name) if name_type == 'IP Address': - ip = netaddr.IPAddress(name) - if not check_networks(ip, allowed_networks): - raise ValidationError("Address '%s' not allowed (doesn't" + if not check_networks(name, allowed_networks): + raise ValidationError("IP '%s' not allowed (doesn't" " match known networks)" % name) @@ -156,7 +151,7 @@ def blacklist_names(csr, domains=[], **kwargs): "consider disabling the step or providing a list") return - CNs = csr.get_subject().get_entries_by_nid(x509_name.NID_commonName) + CNs = csr.get_subject().get_entries_by_oid(x509_name.OID_commonName) if len(CNs) > 0: cn = csr_get_cn(csr) if check_domains(cn, domains): @@ -198,45 +193,41 @@ def extensions(csr=None, allowed_extensions=[], **kwargs): def key_usage(csr=None, allowed_usage=None, **kwargs): """Ensure only accepted key usages are specified.""" - allowed = set(allowed_usage) + allowed = set(extension.LONG_KEY_USAGE_NAMES.get(x, x) for x in + allowed_usage) + denied = set() for ext in (csr.get_extensions() or []): - if ext.get_name() == 'keyUsage': - usages = set(usage.strip() for usage in ext.get_value().split(',')) - if usages & allowed != usages: - raise ValidationError("Found some not allowed key usages: %s" - % ', '.join(usages - allowed)) + if isinstance(ext, extension.X509ExtensionKeyUsage): + usages = set(ext.get_all_usages()) + denied = denied | (usages - allowed) + if denied: + raise ValidationError("Found some not allowed key usages: %s" + % ', '.join(denied)) def ca_status(csr=None, ca_requested=False, **kwargs): """Ensure the request has/hasn't got the CA flag.""" - + request_ca_flags = False for ext in (csr.get_extensions() or []): - ext_name = ext.get_name() - if ext_name == 'basicConstraints': - options = [opt.strip() for opt in ext.get_value().split(",")] - for option in options: - parts = option.split(":") - if len(parts) != 2: - raise ValidationError("Invalid basic constraints flag") - - if parts[0] == 'CA': - if parts[1] != str(ca_requested).upper(): - raise ValidationError("Invalid CA status, 'CA:%s'" - " requested" % parts[1]) - elif parts[0] == 'pathlen': - # errr.. it's ok, I guess - pass - else: - raise ValidationError("Invalid basic constraints option") - elif ext_name == 'keyUsage': - usages = set(usage.strip() for usage in ext.get_value().split(',')) - has_cert_sign = ('Certificate Sign' in usages) - has_crl_sign = ('CRL Sign' in usages) - if ca_requested != has_cert_sign or ca_requested != has_crl_sign: - raise ValidationError("Key usage doesn't match requested CA" - " status (keyCertSign/cRLSign: %s/%s)" - % (has_cert_sign, has_crl_sign)) + if isinstance(ext, extension.X509ExtensionBasicConstraints): + if ext.get_ca(): + if not ca_requested: + raise ValidationError( + "CA status requested, but not allowed") + request_ca_flags = True + elif isinstance(ext, extension.X509ExtensionKeyUsage): + has_cert_sign = ext.get_usage('keyCertSign') + has_crl_sign = ext.get_usage('cRLSign') + if has_crl_sign or has_cert_sign: + if not ca_requested: + raise ValidationError( + "Key usage doesn't match requested CA status " + "(keyCertSign/cRLSign: %s/%s)" + % (has_cert_sign, has_crl_sign)) + request_ca_flags = True + if ca_requested and not request_ca_flags: + raise ValidationError("CA flags required") def source_cidrs(request=None, cidrs=None, **kwargs): diff --git a/requirements.txt b/requirements.txt index c82c2ce..cc86856 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,8 @@ # of appearance. Changing the order has an impact on the overall integration # process, which may cause wedges in the gate later. cryptography>=0.9.1 # Apache-2.0 +pyasn1 +pyasn1_modules pecan>=0.8.0 Paste netaddr>=0.7.12 diff --git a/tests/X509/test_extension.py b/tests/X509/test_extension.py new file mode 100644 index 0000000..000a8c2 --- /dev/null +++ b/tests/X509/test_extension.py @@ -0,0 +1,144 @@ +# -*- coding:utf-8 -*- +# +# Copyright 2014 Hewlett-Packard Development Company, L.P. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import unittest + +import netaddr +from pyasn1_modules import rfc2459 # X509v3 + +from anchor.X509 import errors +from anchor.X509 import extension + + +class TestExtensionBase(unittest.TestCase): + def test_no_spec(self): + with self.assertRaises(errors.X509Error): + extension.X509Extension() + + def test_invalid_asn(self): + with self.assertRaises(errors.X509Error): + extension.X509Extension("foobar") + + def test_unknown_extension_str(self): + asn1 = rfc2459.Extension() + asn1['extnID'] = rfc2459.univ.ObjectIdentifier('1.2.3.4') + asn1['critical'] = False + asn1['extnValue'] = "foobar" + ext = extension.X509Extension(asn1) + self.assertEqual("1.2.3.4: ", str(ext)) + + def test_construct(self): + asn1 = rfc2459.Extension() + asn1['extnID'] = rfc2459.univ.ObjectIdentifier('1.2.3.4') + asn1['critical'] = False + asn1['extnValue'] = "foobar" + ext = extension.construct_extension(asn1) + self.assertIsInstance(ext, extension.X509Extension) + + def test_construct_invalid_type(self): + with self.assertRaises(errors.X509Error): + extension.construct_extension("foobar") + + def test_critical(self): + asn1 = rfc2459.Extension() + asn1['extnID'] = rfc2459.univ.ObjectIdentifier('1.2.3.4') + asn1['critical'] = False + asn1['extnValue'] = "foobar" + ext = extension.construct_extension(asn1) + self.assertFalse(ext.get_critical()) + ext.set_critical(True) + self.assertTrue(ext.get_critical()) + + +class TestBasicConstraints(unittest.TestCase): + def setUp(self): + self.ext = extension.X509ExtensionBasicConstraints() + + def test_str(self): + self.assertEqual(str(self.ext), + "basicConstraints: CA: FALSE, pathLen: None") + + def test_ca(self): + self.ext.set_ca(True) + self.assertTrue(self.ext.get_ca()) + self.ext.set_ca(False) + self.assertFalse(self.ext.get_ca()) + + def test_pathlen(self): + self.ext.set_path_len_constraint(1) + self.assertEqual(1, self.ext.get_path_len_constraint()) + + +class TestKeyUsage(unittest.TestCase): + def setUp(self): + self.ext = extension.X509ExtensionKeyUsage() + + def test_usage_set(self): + self.ext.set_usage('digitalSignature', True) + self.ext.set_usage('keyAgreement', False) + self.assertTrue(self.ext.get_usage('digitalSignature')) + self.assertFalse(self.ext.get_usage('keyAgreement')) + + def test_usage_reset(self): + self.ext.set_usage('digitalSignature', True) + self.ext.set_usage('digitalSignature', False) + self.assertFalse(self.ext.get_usage('digitalSignature')) + + def test_usage_unset(self): + self.assertFalse(self.ext.get_usage('keyAgreement')) + + def test_get_all_usage(self): + self.ext.set_usage('digitalSignature', True) + self.ext.set_usage('keyAgreement', False) + self.ext.set_usage('keyEncipherment', True) + self.assertEqual(set(['digitalSignature', 'keyEncipherment']), + set(self.ext.get_all_usages())) + + def test_str(self): + self.ext.set_usage('digitalSignature', True) + self.assertEqual("keyUsage: digitalSignature", str(self.ext)) + + +class TestSubjectAltName(unittest.TestCase): + def setUp(self): + self.ext = extension.X509ExtensionSubjectAltName() + self.domain = 'some.domain' + self.ip = netaddr.IPAddress('1.2.3.4') + self.ip6 = netaddr.IPAddress('::1') + + def test_dns_ids(self): + self.ext.add_dns_id(self.domain) + self.ext.add_ip(self.ip) + self.assertEqual([self.domain], self.ext.get_dns_ids()) + + def test_ips(self): + self.ext.add_dns_id(self.domain) + self.ext.add_ip(self.ip) + self.assertEqual([self.ip], self.ext.get_ips()) + + def test_ipv6(self): + self.ext.add_ip(self.ip6) + self.assertEqual([self.ip6], self.ext.get_ips()) + + def test_add_ip_invalid(self): + with self.assertRaises(errors.X509Error): + self.ext.add_ip("abcdef") + + def test_str(self): + self.ext.add_dns_id(self.domain) + self.ext.add_ip(self.ip) + self.assertEqual("subjectAltName: DNS:some.domain, IP:1.2.3.4", + str(self.ext)) diff --git a/tests/X509/test_message_digest.py b/tests/X509/test_message_digest.py deleted file mode 100644 index 2cdb0bf..0000000 --- a/tests/X509/test_message_digest.py +++ /dev/null @@ -1,92 +0,0 @@ -# -*- coding:utf-8 -*- -# -# Copyright 2014 Hewlett-Packard Development Company, L.P. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - -import unittest - -from anchor.X509 import message_digest - - -class TestMessageDigest(unittest.TestCase): - data = b"this is test data to test with" - - def setUp(self): - super(TestMessageDigest, self).setUp() - - def tearDown(self): - super(TestMessageDigest, self).tearDown() - - def test_bad_algo(self): - self.assertRaises(message_digest.MessageDigestError, - message_digest.MessageDigest, - 'BAD') - - def test_md5(self): - v = "B2F81E9F287884AF6A8B3E8EFB96C711" - md = message_digest.MessageDigest("md5") - md.update(TestMessageDigest.data) - ret = md.final() - self.assertEqual(ret, v) - - def test_ripmed160(self): - v = "BA5CCC4574D676266D821269CA77BFFD7FD9FCB0" - md = message_digest.MessageDigest("ripemd160") - md.update(TestMessageDigest.data) - ret = md.final() - self.assertEqual(ret, v) - - def test_sha224(self): - v = "675170C12E88D549DB0F608AD6857103D7B792F29FACFCC53173F178" - md = message_digest.MessageDigest("sha224") - md.update(TestMessageDigest.data) - ret = md.final() - self.assertEqual(ret, v) - - def test_sha256(self): - v = "91F672E796E84BECC6F051A47D7392BD789AEA7D55090588F212CF041C862678" - md = message_digest.MessageDigest("sha256") - md.update(TestMessageDigest.data) - ret = md.final() - self.assertEqual(ret, v) - - def test_sha384(self): - v = ("9667AF42DF2E6B81EE679757BB207A3F9BB7CED49CF838FF3ED8237C9B15291B" - "15") - md = message_digest.MessageDigest("sha384") - md.update(TestMessageDigest.data) - ret = md.final() - self.assertEqual(ret, v) - - def test_sha512(self): - v = ("283B3ECD8AE687226C3EA46B59F65E5CA50A11735C9C14BED11F0CCB515707B5" - "1031145ED8AE4B35B24B91F26E70AC0ACAC37B5BEE933B28834FE6447D1298CB" - ) - md = message_digest.MessageDigest("sha512") - md.update(TestMessageDigest.data) - ret = md.final() - self.assertEqual(ret, v) - - def test_algorithms(self): - algs = [ - "md5", - "ripemd160", - "sha224", - "sha256", - "sha384", - "sha512" - ] - valid = message_digest.MessageDigest.getValidAlgorithms() - for alg in algs: - self.assertTrue(alg in valid) diff --git a/tests/X509/test_utils.py b/tests/X509/test_utils.py index d617f3b..5829d0e 100644 --- a/tests/X509/test_utils.py +++ b/tests/X509/test_utils.py @@ -14,70 +14,21 @@ # License for the specific language governing permissions and limitations # under the License. -import datetime import unittest -import mock - -from anchor.X509 import errors from anchor.X509 import utils -from cryptography.hazmat.backends.openssl import backend - - -class TestASN1String(unittest.TestCase): - # missing in cryptography.io - V_ASN1_UTF8STRING = 12 - - def test_utf8_string(self): - orig = u"test \u2603 snowman" - encoded = orig.encode('utf-8') - asn1string = backend._lib.ASN1_STRING_type_new(self.V_ASN1_UTF8STRING) - backend._lib.ASN1_STRING_set(asn1string, encoded, len(encoded)) - - res = utils.asn1_string_to_utf8(asn1string) - self.assertEqual(res, orig) - - def test_invalid_string(self): - encoded = b"\xff" - asn1string = backend._lib.ASN1_STRING_type_new(self.V_ASN1_UTF8STRING) - backend._lib.ASN1_STRING_set(asn1string, encoded, len(encoded)) - - self.assertRaises(errors.ASN1StringError, utils.asn1_string_to_utf8, - asn1string) - class TestASN1Time(unittest.TestCase): - def test_conversion_failure(self): - with mock.patch.object(backend._lib, "ASN1_TIME_to_generalizedtime", - return_value=backend._ffi.NULL): - t = utils.timestamp_to_asn1_time(0) - self.assertRaises(errors.ASN1TimeError, - utils.asn1_time_to_timestamp, t) + def test_round_check(self): + t = 0 + asn1_time = utils.timestamp_to_asn1_time(t) + res = utils.asn1_time_to_timestamp(asn1_time) + self.assertEqual(t, res) - def test_generalizedtime_check_failure(self): - with mock.patch.object(backend._lib, "ASN1_GENERALIZEDTIME_check", - return_value=0): - self.assertRaises(errors.ASN1TimeError, - utils.timestamp_to_asn1_time, 0) - - -class TestTimezone(unittest.TestCase): - def test_utcoffset(self): - tz = utils.create_timezone(1234) - offset = tz.utcoffset(datetime.datetime.now()) - self.assertEqual(datetime.timedelta(minutes=1234), offset) - - def test_dst(self): - tz = utils.create_timezone(1234) - offset = tz.dst(datetime.datetime.now()) - self.assertEqual(datetime.timedelta(0), offset) - - def test_name(self): - tz = utils.create_timezone(1234) - name = tz.tzname(datetime.datetime.now()) - self.assertIsNone(name) - - def test_repr(self): - tz = utils.create_timezone(1234) - self.assertEqual("Timezone +2034", repr(tz)) + def test_post_2050(self): + """Test date post 2050, which causes different encoding.""" + t = 2600000000 + asn1_time = utils.timestamp_to_asn1_time(t) + res = utils.asn1_time_to_timestamp(asn1_time) + self.assertEqual(t, res) diff --git a/tests/X509/test_x509_certificate.py b/tests/X509/test_x509_certificate.py index 385a83c..1db5e1d 100644 --- a/tests/X509/test_x509_certificate.py +++ b/tests/X509/test_x509_certificate.py @@ -18,25 +18,18 @@ import unittest import mock -import sys +import io import textwrap from anchor.X509 import certificate from anchor.X509 import errors as x509_errors +from anchor.X509 import extension from anchor.X509 import name as x509_name - - -# find the class representing an open file; it depends on the python version -# it's used later for mocking -if sys.version_info[0] < 3: - file_class = file # noqa -else: - import _io - file_class = _io.TextIOWrapper +from anchor.X509 import utils class TestX509Cert(unittest.TestCase): - cert_data = textwrap.dedent(""" + cert_data = textwrap.dedent(u""" -----BEGIN CERTIFICATE----- MIICKjCCAZOgAwIBAgIIfeW6dwGe6wMwDQYJKoZIhvcNAQEFBQAwUjELMAkGA1UE BhMCQVUxEzARBgNVBAgTClNvbWUtU3RhdGUxFjAUBgNVBAoTDUhlcnAgRGVycCBw @@ -52,17 +45,70 @@ class TestX509Cert(unittest.TestCase): gTLni27WuVJFVBNoTU1JfoxBSm/RBLdTj92g9N5g -----END CERTIFICATE-----""") + key_dsa_data = textwrap.dedent(""" + -----BEGIN DSA PARAMETERS----- + MIICLAKCAQEA59W1OsK9Tv7DRbxzibGVpBAL2Oz8JhbV3ii7WAat+UfTBLAnfdva + 7UE8odu1l8p41N/8H/tDWgPh6tOgdX0YT9HDsILymQxzUEscliFZKmYg7YdSH3Zd + 6DglOT7CqYxX0r9gK/BOh8ESe3gqKncnThHnO8Eu9wP8HNcrN00EOqP+fJpbS0lu + iifD9JdFY5YpCsLDIvpPbM0NCDuANPo10N3qqC8BuNiu0VfZpRSBcqzU1kwABT5n + y7+8RMh5Xaa7xnhGctJ9s9n+QfWcF/vbgiDOBttb3d8r8Pqvoou8v7Q38Q6zILhf + hajevqjGqZwodbvbHGfFbWapgBjpBIr4zwIhAOq6uryEHQglirWCGFJLQlkzxghy + ctHBRXGuKYb+ltRTAoIBAHRUFxzd1vhjKQ5atIdG0AiXUNm7/uboe21EJDLf4lkE + 7UHDZfwsHXxQHfozzIsp7gHcw7F6AVCgiNRi9vBYOemPswevoWiVKqLTVt1wMogD + EJI6VAQEbBmSrtvyuClCkEAlIY6daX9EV9KqbnetS4/xv4WFQ9FPE47VyQ50vvxK + JSyNZnJ1lN6FUD9R5YYfwERgND8EYJBD10UBKIvtORICTJUfaDAweTWhaVcXUID7 + VGNGPauOdVQzWsWTrQn/f/hbXCB/KXgv1l92D6rEoT2j2YrqIv/qD/ZxPwhBfLdr + W241Cb+LT05LVCokRbWUdjfuO8SdSBAIvT9P6umG/uQ= + -----END DSA PARAMETERS----- + -----BEGIN DSA PRIVATE KEY----- + MIIDVwIBAAKCAQEA59W1OsK9Tv7DRbxzibGVpBAL2Oz8JhbV3ii7WAat+UfTBLAn + fdva7UE8odu1l8p41N/8H/tDWgPh6tOgdX0YT9HDsILymQxzUEscliFZKmYg7YdS + H3Zd6DglOT7CqYxX0r9gK/BOh8ESe3gqKncnThHnO8Eu9wP8HNcrN00EOqP+fJpb + S0luiifD9JdFY5YpCsLDIvpPbM0NCDuANPo10N3qqC8BuNiu0VfZpRSBcqzU1kwA + BT5ny7+8RMh5Xaa7xnhGctJ9s9n+QfWcF/vbgiDOBttb3d8r8Pqvoou8v7Q38Q6z + ILhfhajevqjGqZwodbvbHGfFbWapgBjpBIr4zwIhAOq6uryEHQglirWCGFJLQlkz + xghyctHBRXGuKYb+ltRTAoIBAHRUFxzd1vhjKQ5atIdG0AiXUNm7/uboe21EJDLf + 4lkE7UHDZfwsHXxQHfozzIsp7gHcw7F6AVCgiNRi9vBYOemPswevoWiVKqLTVt1w + MogDEJI6VAQEbBmSrtvyuClCkEAlIY6daX9EV9KqbnetS4/xv4WFQ9FPE47VyQ50 + vvxKJSyNZnJ1lN6FUD9R5YYfwERgND8EYJBD10UBKIvtORICTJUfaDAweTWhaVcX + UID7VGNGPauOdVQzWsWTrQn/f/hbXCB/KXgv1l92D6rEoT2j2YrqIv/qD/ZxPwhB + fLdrW241Cb+LT05LVCokRbWUdjfuO8SdSBAIvT9P6umG/uQCggEBAKrZAppbnKf1 + pzSvE3gTaloitAJG+79BML5h1n67EWuv0i+Fq4eUAVJ23R8GR1HrYw6utZoYbu8u + k8eHrArMfTfbFaLwK/Nv33Hfm3aTTXnY6auLNkpbiZXuCQjWBFhb6F+B42V9/JJ8 + RJ1UV6Y2ajjjMvpeh0cPlARw5UpKBgQ933DhefCWyFBPsPToFvd3uPO+GUN6VpNY + iR7G0AH3/LSVJRuz5/QCp86uLIoU3fBEf1KGYJrkVKlc9DtcNmDXgpP0d3fK+4Jw + bGvi5AD1sQOWryNujyS/d2K/PAagsD0M6XJFgkEV592OSlygbYtuo3t4AtAy8F0f + VHNXq2l01FMCIQCrkk1749eQg4W6j7HfLFvjbDcuIFTw98IKyEZuZ93cdA== + -----END DSA PRIVATE KEY-----""").encode('ascii') + + key_rsa_data = textwrap.dedent(""" + -----BEGIN RSA PRIVATE KEY----- + MIICXAIBAAKBgQCeeqg1Qeccv8hqj1BP9KEJX5QsFCxR62M8plPb5t4sLo8UYfZd + 6kFLcOP8xzwwvx/eFY6Sux52enQ197o8aMwyP77hMhZqtd8NCgLJMVlUbRhwLti0 + SkHFPic0wAg+esfXa6yhd5TxC+bti7MgV/ljA80XQxHH8xOjdOoGN0DHfQIDAQAB + AoGBAJ2ozJpe+7qgGJPaCz3f0izvBwtq7kR49fqqRZbo8HHnx7OxWVVI7LhOkKEy + 2/Bq0xsvOu1CdiXL4LynvIDIiQqLaeINzG48Rbk+0HadbXblt3nDkIWdYII6zHKI + W9ewX4KpHEPbrlEO9BjAlAcYsDIvFIMYpQhtQ+0R/gmZ99WJAkEAz5C2a6FIcMbE + o3aTc9ECq99zY7lxh+6aLpUdIeeHyb/QzfGDBdlbpBAkA6EcxSqp0aqH4xIQnYHa + 3P5ZCShqSwJBAMN1sb76xq94xkg2cxShPFPAE6xKRFyKqLgsBYVtulOdfOtOnjh9 + 1SK2XQQfBRIRdG4Q/gDoCP8XQHpJcWMk+FcCQDnuJqulaOVo5GrG5mJ1nCxCAh98 + G06X7lo/7dCPoRtSuMExvaK9RlFk29hTeAcjYCAPWzupyA9dtarmJg1jRT8CQCKf + gYnb8D/6+9yk0IPR/9ayCooVacCeyz48hgnZowzWs98WwQ4utAd/GED3obVOpDov + Bl9wus889i3zPoOac+cCQCZHredQcJGd4dlthbVtP2NhuPXz33JuETGR9pXtsDUZ + uX/nSq1oo9kUh/dPOz6aP5Ues1YVe3LExmExPBQfwIE= + -----END RSA PRIVATE KEY-----""").encode('ascii') + def setUp(self): super(TestX509Cert, self).setUp() - self.cert = certificate.X509Certificate() - self.cert.from_buffer(TestX509Cert.cert_data) + self.cert = certificate.X509Certificate.from_buffer( + TestX509Cert.cert_data) def tearDown(self): pass def test_bad_data_throws(self): bad_data = ( - "some bad data is " + u"some bad data is " "EHRlc3RAYW5jaG9yLnRlc3QwTDANBgkqhkiG9w0BAQEFAAM7ADA4AjEA6m") cert = certificate.X509Certificate() @@ -72,119 +118,121 @@ class TestX509Cert(unittest.TestCase): def test_get_subject_countryName(self): name = self.cert.get_subject() - entries = name.get_entries_by_nid(x509_name.NID_countryName) + entries = name.get_entries_by_oid(x509_name.OID_countryName) self.assertEqual(len(entries), 1) self.assertEqual(entries[0].get_name(), "countryName") self.assertEqual(entries[0].get_value(), "UK") def test_get_subject_stateOrProvinceName(self): name = self.cert.get_subject() - entries = name.get_entries_by_nid(x509_name.NID_stateOrProvinceName) + entries = name.get_entries_by_oid(x509_name.OID_stateOrProvinceName) self.assertEqual(len(entries), 1) self.assertEqual(entries[0].get_name(), "stateOrProvinceName") self.assertEqual(entries[0].get_value(), "Narnia") def test_get_subject_localityName(self): name = self.cert.get_subject() - entries = name.get_entries_by_nid(x509_name.NID_localityName) + entries = name.get_entries_by_oid(x509_name.OID_localityName) self.assertEqual(len(entries), 1) self.assertEqual(entries[0].get_name(), "localityName") self.assertEqual(entries[0].get_value(), "Funkytown") def test_get_subject_organizationName(self): name = self.cert.get_subject() - entries = name.get_entries_by_nid(x509_name.NID_organizationName) + entries = name.get_entries_by_oid(x509_name.OID_organizationName) self.assertEqual(len(entries), 1) self.assertEqual(entries[0].get_name(), "organizationName") self.assertEqual(entries[0].get_value(), "Anchor Testing") def test_get_subject_organizationUnitName(self): name = self.cert.get_subject() - entries = name.get_entries_by_nid(x509_name.NID_organizationalUnitName) + entries = name.get_entries_by_oid(x509_name.OID_organizationalUnitName) self.assertEqual(len(entries), 1) self.assertEqual(entries[0].get_name(), "organizationalUnitName") self.assertEqual(entries[0].get_value(), "testing") def test_get_subject_commonName(self): name = self.cert.get_subject() - entries = name.get_entries_by_nid(x509_name.NID_commonName) + entries = name.get_entries_by_oid(x509_name.OID_commonName) self.assertEqual(len(entries), 1) self.assertEqual(entries[0].get_name(), "commonName") self.assertEqual(entries[0].get_value(), "anchor.test") def test_get_subject_emailAddress(self): name = self.cert.get_subject() - entries = name.get_entries_by_nid(x509_name.NID_pkcs9_emailAddress) + entries = name.get_entries_by_oid(x509_name.OID_pkcs9_emailAddress) self.assertEqual(len(entries), 1) self.assertEqual(entries[0].get_name(), "emailAddress") self.assertEqual(entries[0].get_value(), "test@anchor.test") def test_get_issuer_countryName(self): name = self.cert.get_issuer() - entries = name.get_entries_by_nid(x509_name.NID_countryName) + entries = name.get_entries_by_oid(x509_name.OID_countryName) self.assertEqual(len(entries), 1) self.assertEqual(entries[0].get_name(), "countryName") self.assertEqual(entries[0].get_value(), "AU") def test_get_issuer_stateOrProvinceName(self): name = self.cert.get_issuer() - entries = name.get_entries_by_nid(x509_name.NID_stateOrProvinceName) + entries = name.get_entries_by_oid(x509_name.OID_stateOrProvinceName) self.assertEqual(len(entries), 1) self.assertEqual(entries[0].get_name(), "stateOrProvinceName") self.assertEqual(entries[0].get_value(), "Some-State") def test_get_issuer_organizationName(self): name = self.cert.get_issuer() - entries = name.get_entries_by_nid(x509_name.NID_organizationName) + entries = name.get_entries_by_oid(x509_name.OID_organizationName) self.assertEqual(len(entries), 1) self.assertEqual(entries[0].get_name(), "organizationName") self.assertEqual(entries[0].get_value(), "Herp Derp plc") def test_get_issuer_commonName(self): name = self.cert.get_issuer() - entries = name.get_entries_by_nid(x509_name.NID_commonName) + entries = name.get_entries_by_oid(x509_name.OID_commonName) self.assertEqual(len(entries), 1) self.assertEqual(entries[0].get_name(), "commonName") self.assertEqual(entries[0].get_value(), "herp.derp.plc") def test_set_subject(self): name = x509_name.X509Name() - name.add_name_entry(x509_name.NID_countryName, 'UK') + name.add_name_entry(x509_name.OID_countryName, 'UK') self.cert.set_subject(name) name = self.cert.get_subject() - entries = name.get_entries_by_nid(x509_name.NID_countryName) + entries = name.get_entries_by_oid(x509_name.OID_countryName) self.assertEqual(len(entries), 1) self.assertEqual(entries[0].get_name(), "countryName") self.assertEqual(entries[0].get_value(), "UK") def test_set_issuer(self): name = x509_name.X509Name() - name.add_name_entry(x509_name.NID_countryName, 'UK') + name.add_name_entry(x509_name.OID_countryName, 'UK') self.cert.set_issuer(name) name = self.cert.get_issuer() - entries = name.get_entries_by_nid(x509_name.NID_countryName) + entries = name.get_entries_by_oid(x509_name.OID_countryName) self.assertEqual(len(entries), 1) self.assertEqual(entries[0].get_name(), "countryName") self.assertEqual(entries[0].get_value(), "UK") def test_read_from_file(self): open_name = 'anchor.X509.certificate.open' + f = io.StringIO(TestX509Cert.cert_data) with mock.patch(open_name, create=True) as mock_open: - mock_open.return_value = mock.MagicMock(spec=file_class) - m_file = mock_open.return_value.__enter__.return_value - m_file.read.return_value = TestX509Cert.cert_data + mock_open.return_value = f - cert = certificate.X509Certificate() - cert.from_file("some_path") + cert = certificate.X509Certificate.from_file("some_path") name = cert.get_subject() - entries = name.get_entries_by_nid(x509_name.NID_countryName) + entries = name.get_entries_by_oid(x509_name.OID_countryName) self.assertEqual(entries[0].get_value(), "UK") def test_get_fingerprint(self): fp = self.cert.get_fingerprint() - self.assertEqual(fp, "56D61AC583BDDD4B44EEB479EF6C998F") + self.assertEqual(fp, "634A8CD10C81F1CD7A7E140921B4D9CA") + + def test_get_fingerprint_invalid_hash(self): + with self.assertRaises(x509_errors.X509Error): + self.cert.get_fingerprint('no_such_hash') def test_sign_bad_md(self): self.assertRaises(x509_errors.X509Error, @@ -194,7 +242,7 @@ class TestX509Cert(unittest.TestCase): def test_sign_bad_key(self): self.assertRaises(x509_errors.X509Error, self.cert.sign, - self.cert._ffi.NULL) + None) def test_get_version(self): v = self.cert.get_version() @@ -222,3 +270,40 @@ class TestX509Cert(unittest.TestCase): self.cert.set_not_after(0) # seconds since epoch val = self.cert.get_not_after() self.assertEqual(0, val) + + def test_get_extensions(self): + exts = self.cert.get_extensions() + self.assertEqual(2, len(exts)) + + def test_add_extensions(self): + bc = extension.X509ExtensionBasicConstraints() + self.cert.add_extension(bc, 2) + exts = self.cert.get_extensions() + self.assertEqual(3, len(exts)) + + def test_add_extensions_invalid(self): + with self.assertRaises(x509_errors.X509Error): + self.cert.add_extension("abcdef", 2) + + def test_sign_rsa_sha1(self): + key = utils.get_private_key_from_bytes(self.key_rsa_data) + self.cert.sign(key, 'sha1') + self.assertEqual(self.cert.get_fingerprint(), + "BA1B5C97D68EAE738FD10657E6F0B143") + + def test_sign_dsa_sha1(self): + key = utils.get_private_key_from_bytes(self.key_dsa_data) + self.cert.sign(key, 'sha1') + # TODO(stan): add verification; DSA signatures are not + # deterministic which means right now we can only make sure it + # doesn't raise exceptions + + def test_sign_unknown_key(self): + key = object() + with self.assertRaises(x509_errors.X509Error): + self.cert.sign(key, 'sha1') + + def test_sign_unknown_hash(self): + key = utils.get_private_key_from_bytes(self.key_rsa_data) + with self.assertRaises(x509_errors.X509Error): + self.cert.sign(key, 'no_such_hash') diff --git a/tests/X509/test_x509_csr.py b/tests/X509/test_x509_csr.py index c63c028..27d8326 100644 --- a/tests/X509/test_x509_csr.py +++ b/tests/X509/test_x509_csr.py @@ -14,29 +14,21 @@ # License for the specific language governing permissions and limitations # under the License. -import sys +import io import textwrap import unittest -from cryptography.hazmat.backends.openssl import backend import mock +from pyasn1_modules import rfc2459 from anchor.X509 import errors as x509_errors +from anchor.X509 import extension from anchor.X509 import name as x509_name from anchor.X509 import signing_request -# find the class representing an open file; it depends on the python version -# it's used later for mocking -if sys.version_info[0] < 3: - file_class = file # noqa -else: - import _io - file_class = _io.TextIOWrapper - - class TestX509Csr(unittest.TestCase): - csr_data = textwrap.dedent(""" + csr_data = textwrap.dedent(u""" -----BEGIN CERTIFICATE REQUEST----- MIIBWTCCARMCAQAwgZQxCzAJBgNVBAYTAlVLMQ8wDQYDVQQIEwZOYXJuaWExEjAQ BgNVBAcTCUZ1bmt5dG93bjEXMBUGA1UEChMOQW5jaG9yIFRlc3RpbmcxEDAOBgNV @@ -50,41 +42,53 @@ class TestX509Csr(unittest.TestCase): def setUp(self): super(TestX509Csr, self).setUp() - self.csr = signing_request.X509Csr() - self.csr.from_buffer(TestX509Csr.csr_data) + self.csr = signing_request.X509Csr.from_buffer(TestX509Csr.csr_data) def tearDown(self): pass - def test_get_pubkey_bits(self): - # some OpenSSL gumph to test a reasonable attribute of the pubkey + def test_get_pubkey(self): pubkey = self.csr.get_pubkey() - size = backend._lib.EVP_PKEY_bits(pubkey) - self.assertEqual(size, 384) + self.assertEqual(pubkey['algorithm']['algorithm'], + rfc2459.rsaEncryption) def test_get_extensions(self): exts = self.csr.get_extensions() self.assertEqual(len(exts), 2) - self.assertEqual(str(exts[0]), "basicConstraints CA:FALSE") - self.assertEqual(str(exts[1]), ("keyUsage Digital Signature, Non " - "Repudiation, Key Encipherment")) + self.assertFalse(exts[0].get_ca()) + self.assertIsNone(exts[0].get_path_len_constraint()) + self.assertTrue(exts[1].get_usage('digitalSignature')) + self.assertTrue(exts[1].get_usage('nonRepudiation')) + self.assertTrue(exts[1].get_usage('keyEncipherment')) + self.assertFalse(exts[1].get_usage('cRLSign')) + + def test_add_extension(self): + csr = signing_request.X509Csr() + bc = extension.X509ExtensionBasicConstraints() + csr.add_extension(bc) + self.assertEqual(1, len(csr.get_extensions())) + csr.add_extension(bc) + self.assertEqual(2, len(csr.get_extensions())) + + def test_add_extension_invalid_type(self): + csr = signing_request.X509Csr() + with self.assertRaises(x509_errors.X509Error): + csr.add_extension(1234) def test_read_from_file(self): open_name = 'anchor.X509.signing_request.open' + f = io.StringIO(TestX509Csr.csr_data) with mock.patch(open_name, create=True) as mock_open: - mock_open.return_value = mock.MagicMock(spec=file_class) - m_file = mock_open.return_value.__enter__.return_value - m_file.read.return_value = TestX509Csr.csr_data - csr = signing_request.X509Csr() - csr.from_file("some_path") + mock_open.return_value = f + csr = signing_request.X509Csr.from_file("some_path") name = csr.get_subject() - entries = name.get_entries_by_nid(x509_name.NID_countryName) + entries = name.get_entries_by_oid(x509_name.OID_countryName) self.assertEqual(entries[0].get_value(), "UK") def test_bad_data_throws(self): bad_data = ( - "some bad data is " + u"some bad data is " "EHRlc3RAYW5jaG9yLnRlc3QwTDANBgkqhkiG9w0BAQEFAAM7ADA4AjEA6m") csr = signing_request.X509Csr() @@ -94,49 +98,49 @@ class TestX509Csr(unittest.TestCase): def test_get_subject_countryName(self): name = self.csr.get_subject() - entries = name.get_entries_by_nid(x509_name.NID_countryName) + entries = name.get_entries_by_oid(x509_name.OID_countryName) self.assertEqual(len(entries), 1) self.assertEqual(entries[0].get_name(), "countryName") self.assertEqual(entries[0].get_value(), "UK") def test_get_subject_stateOrProvinceName(self): name = self.csr.get_subject() - entries = name.get_entries_by_nid(x509_name.NID_stateOrProvinceName) + entries = name.get_entries_by_oid(x509_name.OID_stateOrProvinceName) self.assertEqual(len(entries), 1) self.assertEqual(entries[0].get_name(), "stateOrProvinceName") self.assertEqual(entries[0].get_value(), "Narnia") def test_get_subject_localityName(self): name = self.csr.get_subject() - entries = name.get_entries_by_nid(x509_name.NID_localityName) + entries = name.get_entries_by_oid(x509_name.OID_localityName) self.assertEqual(len(entries), 1) self.assertEqual(entries[0].get_name(), "localityName") self.assertEqual(entries[0].get_value(), "Funkytown") def test_get_subject_organizationName(self): name = self.csr.get_subject() - entries = name.get_entries_by_nid(x509_name.NID_organizationName) + entries = name.get_entries_by_oid(x509_name.OID_organizationName) self.assertEqual(len(entries), 1) self.assertEqual(entries[0].get_name(), "organizationName") self.assertEqual(entries[0].get_value(), "Anchor Testing") def test_get_subject_organizationUnitName(self): name = self.csr.get_subject() - entries = name.get_entries_by_nid(x509_name.NID_organizationalUnitName) + entries = name.get_entries_by_oid(x509_name.OID_organizationalUnitName) self.assertEqual(len(entries), 1) self.assertEqual(entries[0].get_name(), "organizationalUnitName") self.assertEqual(entries[0].get_value(), "testing") def test_get_subject_commonName(self): name = self.csr.get_subject() - entries = name.get_entries_by_nid(x509_name.NID_commonName) + entries = name.get_entries_by_oid(x509_name.OID_commonName) self.assertEqual(len(entries), 1) self.assertEqual(entries[0].get_name(), "commonName") self.assertEqual(entries[0].get_value(), "anchor.test") def test_get_subject_emailAddress(self): name = self.csr.get_subject() - entries = name.get_entries_by_nid(x509_name.NID_pkcs9_emailAddress) + entries = name.get_entries_by_oid(x509_name.OID_pkcs9_emailAddress) self.assertEqual(len(entries), 1) self.assertEqual(entries[0].get_name(), "emailAddress") self.assertEqual(entries[0].get_value(), "test@anchor.test") diff --git a/tests/X509/test_x509_name.py b/tests/X509/test_x509_name.py index 0f4ae30..7110384 100644 --- a/tests/X509/test_x509_name.py +++ b/tests/X509/test_x509_name.py @@ -24,18 +24,18 @@ class TestX509Name(unittest.TestCase): def setUp(self): super(TestX509Name, self).setUp() self.name = x509_name.X509Name() - self.name.add_name_entry(x509_name.NID_countryName, + self.name.add_name_entry(x509_name.OID_countryName, "UK") # must be 2 chars - self.name.add_name_entry(x509_name.NID_stateOrProvinceName, "test_ST") - self.name.add_name_entry(x509_name.NID_localityName, "test_L") - self.name.add_name_entry(x509_name.NID_organizationName, "test_O") - self.name.add_name_entry(x509_name.NID_organizationalUnitName, + self.name.add_name_entry(x509_name.OID_stateOrProvinceName, "test_ST") + self.name.add_name_entry(x509_name.OID_localityName, "test_L") + self.name.add_name_entry(x509_name.OID_organizationName, "test_O") + self.name.add_name_entry(x509_name.OID_organizationalUnitName, "test_OU") - self.name.add_name_entry(x509_name.NID_commonName, "test_CN") - self.name.add_name_entry(x509_name.NID_pkcs9_emailAddress, + self.name.add_name_entry(x509_name.OID_commonName, "test_CN") + self.name.add_name_entry(x509_name.OID_pkcs9_emailAddress, "test_Email") - self.name.add_name_entry(x509_name.NID_surname, "test_SN") - self.name.add_name_entry(x509_name.NID_givenName, "test_GN") + self.name.add_name_entry(x509_name.OID_surname, "test_SN") + self.name.add_name_entry(x509_name.OID_givenName, "test_GN") def tearDown(self): pass @@ -48,7 +48,7 @@ class TestX509Name(unittest.TestCase): def test_set_bad_c_throws(self): self.assertRaises(x509_errors.X509Error, self.name.add_name_entry, - x509_name.NID_countryName, "BAD_WRONG") + x509_name.OID_countryName, "BAD_WRONG") def test_name_to_string(self): val = str(self.name) @@ -57,53 +57,53 @@ class TestX509Name(unittest.TestCase): "SN=test_SN/GN=test_GN")) def test_get_countryName(self): - entries = self.name.get_entries_by_nid(x509_name.NID_countryName) + entries = self.name.get_entries_by_oid(x509_name.OID_countryName) self.assertEqual(len(entries), 1) self.assertEqual(entries[0].get_name(), "countryName") self.assertEqual(entries[0].get_value(), "UK") def test_get_stateOrProvinceName(self): - entries = self.name.get_entries_by_nid( - x509_name.NID_stateOrProvinceName) + entries = self.name.get_entries_by_oid( + x509_name.OID_stateOrProvinceName) self.assertEqual(len(entries), 1) self.assertEqual(entries[0].get_name(), "stateOrProvinceName") self.assertEqual(entries[0].get_value(), "test_ST") def test_get_subject_localityName(self): - entries = self.name.get_entries_by_nid(x509_name.NID_localityName) + entries = self.name.get_entries_by_oid(x509_name.OID_localityName) self.assertEqual(len(entries), 1) self.assertEqual(entries[0].get_name(), "localityName") self.assertEqual(entries[0].get_value(), "test_L") def test_get_organizationName(self): - entries = self.name.get_entries_by_nid(x509_name.NID_organizationName) + entries = self.name.get_entries_by_oid(x509_name.OID_organizationName) self.assertEqual(len(entries), 1) self.assertEqual(entries[0].get_name(), "organizationName") self.assertEqual(entries[0].get_value(), "test_O") def test_get_organizationUnitName(self): - entries = self.name.get_entries_by_nid( - x509_name.NID_organizationalUnitName) + entries = self.name.get_entries_by_oid( + x509_name.OID_organizationalUnitName) self.assertEqual(len(entries), 1) self.assertEqual(entries[0].get_name(), "organizationalUnitName") self.assertEqual(entries[0].get_value(), "test_OU") def test_get_commonName(self): - entries = self.name.get_entries_by_nid(x509_name.NID_commonName) + entries = self.name.get_entries_by_oid(x509_name.OID_commonName) self.assertEqual(len(entries), 1) self.assertEqual(entries[0].get_name(), "commonName") self.assertEqual(entries[0].get_value(), "test_CN") def test_get_emailAddress(self): - entries = self.name.get_entries_by_nid( - x509_name.NID_pkcs9_emailAddress) + entries = self.name.get_entries_by_oid( + x509_name.OID_pkcs9_emailAddress) self.assertEqual(len(entries), 1) self.assertEqual(entries[0].get_name(), "emailAddress") self.assertEqual(entries[0].get_value(), "test_Email") def test_entry_to_string(self): - entries = self.name.get_entries_by_nid( - x509_name.NID_pkcs9_emailAddress) + entries = self.name.get_entries_by_oid( + x509_name.OID_pkcs9_emailAddress) self.assertEqual(len(entries), 1) self.assertEqual(str(entries[0]), "emailAddress: test_Email") diff --git a/tests/test_certificate_ops.py b/tests/test_certificate_ops.py index 852a9e5..29f8221 100644 --- a/tests/test_certificate_ops.py +++ b/tests/test_certificate_ops.py @@ -29,7 +29,7 @@ class CertificateOpsTests(unittest.TestCase): def setUp(self): # This is a CSR with CN=anchor-test.example.com self.expected_cn = "anchor-test.example.com" - self.csr = textwrap.dedent(""" + self.csr = textwrap.dedent(u""" -----BEGIN CERTIFICATE REQUEST----- MIIEsDCCApgCAQAwazELMAkGA1UEBhMCVVMxEzARBgNVBAgTCkNhbGlmb3JuaWEx FjAUBgNVBAcTDU1vdW50YWluIFZpZXcxDTALBgNVBAoTBEFjbWUxIDAeBgNVBAMT @@ -67,16 +67,16 @@ class CertificateOpsTests(unittest.TestCase): """Test basic success path for parse_csr.""" result = certificate_ops.parse_csr(self.csr, 'pem') subject = result.get_subject() - actual_cn = subject.get_entries_by_nid( - x509_name.NID_commonName)[0].get_value() + actual_cn = subject.get_entries_by_oid( + x509_name.OID_commonName)[0].get_value() self.assertEqual(actual_cn, self.expected_cn) def test_parse_csr_success2(self): """Test basic success path for parse_csr.""" result = certificate_ops.parse_csr(self.csr, 'PEM') subject = result.get_subject() - actual_cn = subject.get_entries_by_nid( - x509_name.NID_commonName)[0].get_value() + actual_cn = subject.get_entries_by_oid( + x509_name.OID_commonName)[0].get_value() self.assertEqual(actual_cn, self.expected_cn) def test_parse_csr_fail1(self): diff --git a/tests/test_functional.py b/tests/test_functional.py index ce7ce1c..af90d3f 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -58,7 +58,7 @@ class TestFunctional(unittest.TestCase): } """ - csr_good = textwrap.dedent(""" + csr_good = textwrap.dedent(u""" -----BEGIN CERTIFICATE REQUEST----- MIIEDzCCAncCAQAwcjELMAkGA1UEBhMCR0IxEzARBgNVBAgTCkNhbGlmb3JuaWEx FjAUBgNVBAcTDVNhbiBGcmFuY3NpY28xDTALBgNVBAoTBE9TU0cxDTALBgNVBAsT @@ -84,7 +84,7 @@ class TestFunctional(unittest.TestCase): tR7XqQGqJKca/vRTfJ+zIAxMEeH1N9Lx7YBO6VdVja+yG1E= -----END CERTIFICATE REQUEST-----""") - csr_bad = textwrap.dedent(""" + csr_bad = textwrap.dedent(u""" -----BEGIN CERTIFICATE REQUEST----- MIIBWTCCARMCAQAwgZQxCzAJBgNVBAYTAlVLMQ8wDQYDVQQIEwZOYXJuaWExEjAQ BgNVBAcTCUZ1bmt5dG93bjEXMBUGA1UEChMOQW5jaG9yIFRlc3RpbmcxEDAOBgNV @@ -149,8 +149,7 @@ class TestFunctional(unittest.TestCase): resp = self.app.post('/sign', data, expect_errors=False) self.assertEqual(200, resp.status_int) - cert = X509_cert.X509Certificate() - cert.from_buffer(resp.text) + cert = X509_cert.X509Certificate.from_buffer(resp.text) # make sure the cert is what we asked for self.assertEqual(("/C=GB/ST=California/L=San Francsico/O=OSSG" diff --git a/tests/validators/test_base_validation_functions.py b/tests/validators/test_base_validation_functions.py index 433d233..d412b8f 100644 --- a/tests/validators/test_base_validation_functions.py +++ b/tests/validators/test_base_validation_functions.py @@ -24,7 +24,7 @@ from anchor.X509 import signing_request class TestBaseValidators(unittest.TestCase): - csr_data_with_cn = textwrap.dedent(""" + csr_data_with_cn = textwrap.dedent(u""" -----BEGIN CERTIFICATE REQUEST----- MIIDBTCCAe0CAQAwgb8xCzAJBgNVBAYTAlVTMRMwEQYDVQQIEwpDYWxpZm9ybmlh MRYwFAYDVQQHEw1TYW4gRnJhbmNpc2NvMSEwHwYDVQQKExhPcGVuU3RhY2sgU2Vj @@ -51,7 +51,7 @@ class TestBaseValidators(unittest.TestCase): CN=ossg.test.com/emailAddress=openstack-security@lists.openstack.org """ - csr_data_without_cn = textwrap.dedent(""" + csr_data_without_cn = textwrap.dedent(u""" -----BEGIN CERTIFICATE REQUEST----- MIIC7TCCAdUCAQAwgacxCzAJBgNVBAYTAlVTMRMwEQYDVQQIDApDYWxpZm9ybmlh MRYwFAYDVQQHDA1TYW4gRnJhbmNpc2NvMSEwHwYDVQQKDBhPcGVuU3RhY2sgU2Vj @@ -79,8 +79,8 @@ class TestBaseValidators(unittest.TestCase): def setUp(self): super(TestBaseValidators, self).setUp() - self.csr = signing_request.X509Csr() - self.csr.from_buffer(TestBaseValidators.csr_data_with_cn) + self.csr = signing_request.X509Csr.from_buffer( + TestBaseValidators.csr_data_with_cn) def tearDown(self): super(TestBaseValidators, self).tearDown() @@ -89,7 +89,8 @@ class TestBaseValidators(unittest.TestCase): name = validators.csr_get_cn(self.csr) self.assertEqual(name, "ossg.test.com") - self.csr.from_buffer(TestBaseValidators.csr_data_without_cn) + self.csr = signing_request.X509Csr.from_buffer( + TestBaseValidators.csr_data_without_cn) with self.assertRaises(validators.ValidationError): validators.csr_get_cn(self.csr) diff --git a/tests/validators/test_callable_validators.py b/tests/validators/test_callable_validators.py index fdddcce..2faa904 100644 --- a/tests/validators/test_callable_validators.py +++ b/tests/validators/test_callable_validators.py @@ -20,7 +20,9 @@ import mock import netaddr from anchor import validators +from anchor.X509 import extension as x509_ext from anchor.X509 import name as x509_name +from anchor.X509 import signing_request as x509_csr class TestValidators(unittest.TestCase): @@ -45,262 +47,174 @@ class TestValidators(unittest.TestCase): 'example.com', [])) def test_common_name_with_two_CN(self): - ext_mock = mock.MagicMock() - ext_mock.get_name.return_value = "subjectAltName" - - csr_config = { - 'get_extensions.return_value': [ext_mock], - 'get_subject.return_value.get_entries_by_nid.return_value': - ['dummy_value', 'dummy_value'], - } - csr_mock = mock.MagicMock(**csr_config) + csr = x509_csr.X509Csr() + name = csr.get_subject() + name.add_name_entry(x509_name.OID_commonName, "dummy_value") + name.add_name_entry(x509_name.OID_commonName, "dummy_value") with self.assertRaises(validators.ValidationError) as e: validators.common_name( - csr=csr_mock, + csr=csr, allowed_domains=[], allowed_networks=[]) self.assertEqual("Too many CNs in the request", str(e.exception)) def test_common_name_no_CN(self): - csr_config = { - 'get_subject.return_value.__len__.return_value': 0, - 'get_subject.return_value.get_entries_by_nid.return_value': - [] - } - csr_mock = mock.MagicMock(**csr_config) + csr = x509_csr.X509Csr() with self.assertRaises(validators.ValidationError) as e: validators.common_name( - csr=csr_mock, + csr=csr, allowed_domains=[], allowed_networks=[]) self.assertEqual("Alt subjects have to exist if the main subject" " doesn't", str(e.exception)) def test_common_name_good_CN(self): - cn_mock = mock.MagicMock() - cn_mock.get_value.return_value = 'master.test.com' - - csr_config = { - 'get_subject.return_value.__len__.return_value': 1, - 'get_subject.return_value.get_entries_by_nid.return_value': - [cn_mock], - } - csr_mock = mock.MagicMock(**csr_config) + csr = x509_csr.X509Csr() + name = csr.get_subject() + name.add_name_entry(x509_name.OID_commonName, "master.test.com") self.assertEqual( None, validators.common_name( - csr=csr_mock, + csr=csr, allowed_domains=['.test.com'], ) ) def test_common_name_bad_CN(self): - name = x509_name.X509Name() - name.add_name_entry(x509_name.NID_commonName, 'test.baddomain.com') - - csr_mock = mock.MagicMock() - csr_mock.get_subject.return_value = name + csr = x509_csr.X509Csr() + name = csr.get_subject() + name.add_name_entry(x509_name.OID_commonName, 'test.baddomain.com') with self.assertRaises(validators.ValidationError) as e: validators.common_name( - csr=csr_mock, + csr=csr, allowed_domains=['.test.com']) self.assertEqual("Domain 'test.baddomain.com' not allowed (does not " "match known domains)", str(e.exception)) def test_common_name_ip_good(self): - name = x509_name.X509Name() - name.add_name_entry(x509_name.NID_commonName, '10.1.1.1') - - csr_mock = mock.MagicMock() - csr_mock.get_subject.return_value = name + csr = x509_csr.X509Csr() + name = csr.get_subject() + name.add_name_entry(x509_name.OID_commonName, '10.1.1.1') self.assertEqual( None, validators.common_name( - csr=csr_mock, + csr=csr, allowed_domains=['.test.com'], allowed_networks=['10/8'] ) ) def test_common_name_ip_bad(self): - name = x509_name.X509Name() - name.add_name_entry(x509_name.NID_commonName, '15.1.1.1') - - csr_mock = mock.MagicMock() - csr_mock.get_subject.return_value = name + csr = x509_csr.X509Csr() + name = csr.get_subject() + name.add_name_entry(x509_name.OID_commonName, '15.1.1.1') with self.assertRaises(validators.ValidationError) as e: validators.common_name( - csr=csr_mock, + csr=csr, allowed_domains=['.test.com'], allowed_networks=['10/8']) self.assertEqual("Address '15.1.1.1' not allowed (does not " "match known networks)", str(e.exception)) def test_alternative_names_good_domain(self): - ext_mock = mock.MagicMock() - ext_mock.get_value.return_value = 'DNS:master.test.com' - ext_mock.get_name.return_value = 'subjectAltName' + csr = x509_csr.X509Csr() + ext = x509_ext.X509ExtensionSubjectAltName() + ext.add_dns_id('master.test.com') + csr.add_extension(ext) - csr_mock = mock.MagicMock() - csr_mock.get_extensions.return_value = [ext_mock] self.assertEqual( None, validators.alternative_names( - csr=csr_mock, + csr=csr, allowed_domains=['.test.com'], ) ) def test_alternative_names_bad_domain(self): - ext_mock = mock.MagicMock() - ext_mock.get_value.return_value = 'DNS:test.baddomain.com' - ext_mock.get_name.return_value = 'subjectAltName' - - csr_mock = mock.MagicMock() - csr_mock.get_extensions.return_value = [ext_mock] + csr = x509_csr.X509Csr() + ext = x509_ext.X509ExtensionSubjectAltName() + ext.add_dns_id('test.baddomain.com') + csr.add_extension(ext) with self.assertRaises(validators.ValidationError) as e: validators.alternative_names( - csr=csr_mock, + csr=csr, allowed_domains=['.test.com']) self.assertEqual("Domain 'test.baddomain.com' not allowed (doesn't " "match known domains)", str(e.exception)) - def test_alternative_names_ext(self): - ext_mock = mock.MagicMock() - ext_mock.get_value.return_value = 'BAD,10.1.1.1' - ext_mock.get_name.return_value = 'subjectAltName' - - csr_mock = mock.MagicMock() - csr_mock.get_extensions.return_value = [ext_mock] - - with self.assertRaises(validators.ValidationError) as e: - validators.alternative_names( - csr=csr_mock, - allowed_domains=['.test.com']) - self.assertEqual("Alt name should have 2 parts, but found: 'BAD'", - str(e.exception)) - def test_alternative_names_ip_good(self): - ext_mock = mock.MagicMock() - ext_mock.get_value.return_value = 'IP Address:10.1.1.1' - ext_mock.get_name.return_value = 'subjectAltName' - - csr_mock = mock.MagicMock() - csr_mock.get_extensions.return_value = [ext_mock] + csr = x509_csr.X509Csr() + ext = x509_ext.X509ExtensionSubjectAltName() + ext.add_ip(netaddr.IPAddress('10.1.1.1')) + csr.add_extension(ext) self.assertEqual( None, validators.alternative_names_ip( - csr=csr_mock, + csr=csr, allowed_domains=['.test.com'], allowed_networks=['10/8'] ) ) def test_alternative_names_ip_bad(self): - - ext_mock = mock.MagicMock() - ext_mock.get_value.return_value = 'IP Address:10.1.1.1' - ext_mock.get_name.return_value = 'subjectAltName' - - csr_mock = mock.MagicMock() - csr_mock.get_extensions.return_value = [ext_mock] + csr = x509_csr.X509Csr() + ext = x509_ext.X509ExtensionSubjectAltName() + ext.add_ip(netaddr.IPAddress('10.1.1.1')) + csr.add_extension(ext) with self.assertRaises(validators.ValidationError) as e: validators.alternative_names_ip( - csr=csr_mock, + csr=csr, allowed_domains=['.test.com'], allowed_networks=['99/8']) - self.assertEqual("Address '10.1.1.1' not allowed (doesn't match known " + self.assertEqual("IP '10.1.1.1' not allowed (doesn't match known " "networks)", str(e.exception)) def test_alternative_names_ip_bad_domain(self): - ext_mock = mock.MagicMock() - ext_mock.get_value.return_value = 'DNS:test.baddomain.com' - ext_mock.get_name.return_value = 'subjectAltName' - - csr_mock = mock.MagicMock() - csr_mock.get_extensions.return_value = [ext_mock] + csr = x509_csr.X509Csr() + ext = x509_ext.X509ExtensionSubjectAltName() + ext.add_dns_id('test.baddomain.com') + csr.add_extension(ext) with self.assertRaises(validators.ValidationError) as e: validators.alternative_names_ip( - csr=csr_mock, + csr=csr, allowed_domains=['.test.com']) self.assertEqual("Domain 'test.baddomain.com' not allowed (doesn't " "match known domains)", str(e.exception)) - def test_alternative_names_ip_ext(self): - ext_mock = mock.MagicMock() - ext_mock.get_value.return_value = 'BAD,10.1.1.1' - ext_mock.get_name.return_value = 'subjectAltName' - - csr_mock = mock.MagicMock() - csr_mock.get_extensions.return_value = [ext_mock] - - with self.assertRaises(validators.ValidationError) as e: - validators.alternative_names_ip( - csr=csr_mock, - allowed_domains=['.test.com']) - self.assertEqual("Alt name should have 2 parts, but found: 'BAD'", - str(e.exception)) - - def test_alternative_names_ip_bad_ext(self): - ext_mock = mock.MagicMock() - ext_mock.get_value.return_value = 'BAD:VALUE' - ext_mock.get_name.return_value = 'subjectAltName' - - csr_mock = mock.MagicMock() - csr_mock.get_extensions.return_value = [ext_mock] - - with self.assertRaises(validators.ValidationError) as e: - validators.alternative_names_ip( - csr=csr_mock, - allowed_domains=['.test.com'], - allowed_networks=['99/8']) - self.assertEqual("Alt name 'VALUE' has unexpected type 'BAD'", - str(e.exception)) - def test_server_group_no_prefix1(self): - cn_mock = mock.MagicMock() - cn_mock.get_value.return_value = 'master.test.com' - - csr_config = { - 'get_subject.return_value.get_entries_by_nid.return_value': - [cn_mock], - } - csr_mock = mock.MagicMock(**csr_config) + csr = x509_csr.X509Csr() + name = csr.get_subject() + name.add_name_entry(x509_name.OID_commonName, "master.test.com") self.assertEqual( None, validators.server_group( auth_result=None, - csr=csr_mock, + csr=csr, group_prefixes={} ) ) def test_server_group_no_prefix2(self): - cn_mock = mock.MagicMock() - cn_mock.get_value.return_value = 'nv_master.test.com' - - csr_config = { - 'get_subject.return_value.get_entries_by_nid.return_value': - [cn_mock], - } - csr_mock = mock.MagicMock(**csr_config) + csr = x509_csr.X509Csr() + name = csr.get_subject() + name.add_name_entry(x509_name.OID_commonName, "nv_master.test.com") self.assertEqual( None, validators.server_group( auth_result=None, - csr=csr_mock, + csr=csr, group_prefixes={} ) ) @@ -310,20 +224,15 @@ class TestValidators(unittest.TestCase): auth_result = mock.Mock() auth_result.groups = ['nova'] - cn_mock = mock.MagicMock() - cn_mock.get_value.return_value = 'nv_master.test.com' - - csr_config = { - 'get_subject.return_value.get_entries_by_nid.return_value': - [cn_mock], - } - csr_mock = mock.MagicMock(**csr_config) + csr = x509_csr.X509Csr() + name = csr.get_subject() + name.add_name_entry(x509_name.OID_commonName, "nv_master.test.com") self.assertEqual( None, validators.server_group( auth_result=auth_result, - csr=csr_mock, + csr=csr, group_prefixes={'nv': 'nova', 'sw': 'swift'} ) ) @@ -332,50 +241,41 @@ class TestValidators(unittest.TestCase): auth_result = mock.Mock() auth_result.groups = ['glance'] - cn_mock = mock.MagicMock() - cn_mock.get_value.return_value = 'nv-master.test.com' - - csr_config = { - 'get_subject.return_value.get_entries_by_nid.return_value': - [cn_mock], - } - csr_mock = mock.MagicMock(**csr_config) + csr = x509_csr.X509Csr() + name = csr.get_subject() + name.add_name_entry(x509_name.OID_commonName, "nv-master.test.com") with self.assertRaises(validators.ValidationError) as e: validators.server_group( auth_result=auth_result, - csr=csr_mock, + csr=csr, group_prefixes={'nv': 'nova', 'sw': 'swift'}) self.assertEqual("Server prefix doesn't match user groups", str(e.exception)) def test_extensions_bad(self): - ext_mock = mock.MagicMock() - ext_mock.get_name.return_value = 'BAD' - ext_mock.get_value.return_value = 'BAD' - - csr_mock = mock.MagicMock() - csr_mock.get_extensions.return_value = [ext_mock] + csr = x509_csr.X509Csr() + ext = x509_ext.X509ExtensionKeyUsage() + ext.set_usage('keyCertSign', True) + csr.add_extension(ext) with self.assertRaises(validators.ValidationError) as e: validators.extensions( - csr=csr_mock, - allowed_extensions=['GOOD-1', 'GOOD-2']) - self.assertEqual("Extension 'BAD' not allowed", str(e.exception)) + csr=csr, + allowed_extensions=['basicConstraints', 'nameConstraints']) + self.assertEqual("Extension 'keyUsage' not allowed", str(e.exception)) def test_extensions_good(self): - ext_mock = mock.MagicMock() - ext_mock.get_name.return_value = 'GOOD-1' - ext_mock.get_value.return_value = 'GOOD-1' - - csr_mock = mock.MagicMock() - csr_mock.get_extensions.return_value = [ext_mock] + csr = x509_csr.X509Csr() + ext = x509_ext.X509ExtensionKeyUsage() + ext.set_usage('keyCertSign', True) + csr.add_extension(ext) self.assertEqual( None, validators.extensions( - csr=csr_mock, - allowed_extensions=['GOOD-1', 'GOOD-2'] + csr=csr, + allowed_extensions=['basicConstraints', 'keyUsage'] ) ) @@ -384,204 +284,158 @@ class TestValidators(unittest.TestCase): 'Non Repudiation', 'Key Encipherment'] - ext_mock = mock.MagicMock() - ext_mock.get_name.return_value = 'keyUsage' - ext_mock.get_value.return_value = 'Domination' - - csr_mock = mock.MagicMock() - csr_mock.get_extensions.return_value = [ext_mock] + csr = x509_csr.X509Csr() + ext = x509_ext.X509ExtensionKeyUsage() + ext.set_usage('keyCertSign', True) + csr.add_extension(ext) with self.assertRaises(validators.ValidationError) as e: validators.key_usage( - csr=csr_mock, + csr=csr, allowed_usage=allowed_usage) self.assertEqual("Found some not allowed key usages: " - "Domination", str(e.exception)) + "keyCertSign", str(e.exception)) def test_key_usage_good(self): allowed_usage = ['Digital Signature', 'Non Repudiation', 'Key Encipherment'] - ext_mock = mock.MagicMock() - ext_mock.get_name.return_value = 'keyUsage' - ext_mock.get_value.return_value = 'Key Encipherment, Digital Signature' - - csr_mock = mock.MagicMock() - csr_mock.get_extensions.return_value = [ext_mock] + csr = x509_csr.X509Csr() + ext = x509_ext.X509ExtensionKeyUsage() + ext.set_usage('keyEncipherment', True) + ext.set_usage('digitalSignature', True) + csr.add_extension(ext) self.assertEqual( None, validators.key_usage( - csr=csr_mock, + csr=csr, allowed_usage=allowed_usage ) ) def test_ca_status_good1(self): - ext_mock = mock.MagicMock() - ext_mock.get_name.return_value = 'basicConstraints' - ext_mock.get_value.return_value = 'CA:TRUE' - - csr_mock = mock.MagicMock() - csr_mock.get_extensions.return_value = [ext_mock] + csr = x509_csr.X509Csr() + ext = x509_ext.X509ExtensionBasicConstraints() + ext.set_ca(True) + csr.add_extension(ext) self.assertEqual( None, validators.ca_status( - csr=csr_mock, + csr=csr, ca_requested=True ) ) def test_ca_status_good2(self): - ext_mock = mock.MagicMock() - ext_mock.get_name.return_value = 'basicConstraints' - ext_mock.get_value.return_value = 'CA:FALSE' - - csr_mock = mock.MagicMock() - csr_mock.get_extensions.return_value = [ext_mock] + csr = x509_csr.X509Csr() + ext = x509_ext.X509ExtensionBasicConstraints() + ext.set_ca(False) + csr.add_extension(ext) self.assertEqual( None, validators.ca_status( - csr=csr_mock, + csr=csr, ca_requested=False ) ) - def test_ca_status_bad(self): - ext_mock = mock.MagicMock() - ext_mock.get_name.return_value = 'basicConstraints' - ext_mock.get_value.return_value = 'CA:FALSE' - - csr_mock = mock.MagicMock() - csr_mock.get_extensions.return_value = [ext_mock] + def test_ca_status_forbidden(self): + csr = x509_csr.X509Csr() + ext = x509_ext.X509ExtensionBasicConstraints() + ext.set_ca(True) + csr.add_extension(ext) with self.assertRaises(validators.ValidationError) as e: validators.ca_status( - csr=csr_mock, - ca_requested=True) - self.assertEqual("Invalid CA status, 'CA:FALSE' requested", + csr=csr, + ca_requested=False) + self.assertEqual("CA status requested, but not allowed", str(e.exception)) - def test_ca_status_bad_format1(self): - ext_mock = mock.MagicMock() - ext_mock.get_name.return_value = 'basicConstraints' - ext_mock.get_value.return_value = 'CA~FALSE' - - csr_mock = mock.MagicMock() - csr_mock.get_extensions.return_value = [ext_mock] + def test_ca_status_bad(self): + csr = x509_csr.X509Csr() + ext = x509_ext.X509ExtensionBasicConstraints() + ext.set_ca(False) + csr.add_extension(ext) with self.assertRaises(validators.ValidationError) as e: validators.ca_status( - csr=csr_mock, - ca_requested=False) - self.assertEqual("Invalid basic constraints flag", str(e.exception)) - - def test_ca_status_bad_format2(self): - ext_mock = mock.MagicMock() - ext_mock.get_name.return_value = 'basicConstraints' - ext_mock.get_value.return_value = 'CA:FALSE:DERP' - - csr_mock = mock.MagicMock() - csr_mock.get_extensions.return_value = [ext_mock] - - with self.assertRaises(validators.ValidationError) as e: - validators.ca_status( - csr=csr_mock, - ca_requested=False) - self.assertEqual("Invalid basic constraints flag", str(e.exception)) + csr=csr, + ca_requested=True) + self.assertEqual("CA flags required", + str(e.exception)) def test_ca_status_pathlen(self): - ext_mock = mock.MagicMock() - ext_mock.get_name.return_value = 'basicConstraints' - ext_mock.get_value.return_value = 'pathlen:somthing' - - csr_mock = mock.MagicMock() - csr_mock.get_extensions.return_value = [ext_mock] + csr = x509_csr.X509Csr() + ext = x509_ext.X509ExtensionBasicConstraints() + ext.set_path_len_constraint(1) + csr.add_extension(ext) self.assertEqual( None, validators.ca_status( - csr=csr_mock, + csr=csr, ca_requested=False ) ) - def test_ca_status_bad_value(self): - ext_mock = mock.MagicMock() - ext_mock.get_name.return_value = 'basicConstraints' - ext_mock.get_value.return_value = 'BAD:VALUE' - - csr_mock = mock.MagicMock() - csr_mock.get_extensions.return_value = [ext_mock] - - with self.assertRaises(validators.ValidationError) as e: - validators.ca_status( - csr=csr_mock, - ca_requested=False) - self.assertEqual("Invalid basic constraints option", str(e.exception)) - def test_ca_status_key_usage_bad1(self): - ext_mock = mock.MagicMock() - ext_mock.get_name.return_value = 'keyUsage' - ext_mock.get_value.return_value = 'Certificate Sign' - - csr_mock = mock.MagicMock() - csr_mock.get_extensions.return_value = [ext_mock] + csr = x509_csr.X509Csr() + ext = x509_ext.X509ExtensionKeyUsage() + ext.set_usage('keyCertSign', True) + csr.add_extension(ext) with self.assertRaises(validators.ValidationError) as e: validators.ca_status( - csr=csr_mock, + csr=csr, ca_requested=False) self.assertEqual("Key usage doesn't match requested CA status " "(keyCertSign/cRLSign: True/False)", str(e.exception)) def test_ca_status_key_usage_good1(self): - ext_mock = mock.MagicMock() - ext_mock.get_name.return_value = 'keyUsage' - ext_mock.get_value.return_value = 'Certificate Sign' + csr = x509_csr.X509Csr() + ext = x509_ext.X509ExtensionKeyUsage() + ext.set_usage('keyCertSign', True) + csr.add_extension(ext) - csr_mock = mock.MagicMock() - csr_mock.get_extensions.return_value = [ext_mock] - - with self.assertRaises(validators.ValidationError) as e: + self.assertEqual( + None, validators.ca_status( - csr=csr_mock, - ca_requested=True) - self.assertEqual("Key usage doesn't match requested CA status " - "(keyCertSign/cRLSign: True/False)", str(e.exception)) + csr=csr, + ca_requested=True + ) + ) def test_ca_status_key_usage_bad2(self): - ext_mock = mock.MagicMock() - ext_mock.get_name.return_value = 'keyUsage' - ext_mock.get_value.return_value = 'CRL Sign' - - csr_mock = mock.MagicMock() - csr_mock.get_extensions.return_value = [ext_mock] + csr = x509_csr.X509Csr() + ext = x509_ext.X509ExtensionKeyUsage() + ext.set_usage('cRLSign', True) + csr.add_extension(ext) with self.assertRaises(validators.ValidationError) as e: validators.ca_status( - csr=csr_mock, + csr=csr, ca_requested=False) self.assertEqual("Key usage doesn't match requested CA status " "(keyCertSign/cRLSign: False/True)", str(e.exception)) def test_ca_status_key_usage_good2(self): - ext_mock = mock.MagicMock() - ext_mock.get_name.return_value = 'keyUsage' - ext_mock.get_value.return_value = 'CRL Sign' + csr = x509_csr.X509Csr() + ext = x509_ext.X509ExtensionKeyUsage() + ext.set_usage('cRLSign', True) + csr.add_extension(ext) - csr_mock = mock.MagicMock() - csr_mock.get_extensions.return_value = [ext_mock] - - with self.assertRaises(validators.ValidationError) as e: + self.assertEqual( + None, validators.ca_status( - csr=csr_mock, - ca_requested=True) - self.assertEqual("Key usage doesn't match requested CA status " - "(keyCertSign/cRLSign: False/True)", str(e.exception)) + csr=csr, + ca_requested=True + ) + ) def test_source_cidrs_good(self): request = mock.Mock(client_addr='127.0.0.1') @@ -612,99 +466,65 @@ class TestValidators(unittest.TestCase): str(e.exception)) def test_blacklist_names_good(self): - ext_mock = mock.MagicMock() - ext_mock.get_value.return_value = 'DNS:blah.good' - ext_mock.get_name.return_value = 'subjectAltName' - - csr_mock = mock.MagicMock() - csr_mock.get_extensions.return_value = [ext_mock] + csr = x509_csr.X509Csr() + ext = x509_ext.X509ExtensionSubjectAltName() + ext.add_dns_id('blah.good') + csr.add_extension(ext) self.assertEqual( None, validators.blacklist_names( - csr=csr_mock, + csr=csr, domains=['.bad'], ) ) def test_blacklist_names_bad(self): - ext_mock = mock.MagicMock() - ext_mock.get_value.return_value = 'DNS:blah.bad' - ext_mock.get_name.return_value = 'subjectAltName' - - csr_mock = mock.MagicMock() - csr_mock.get_extensions.return_value = [ext_mock] + csr = x509_csr.X509Csr() + ext = x509_ext.X509ExtensionSubjectAltName() + ext.add_dns_id('blah.bad') + csr.add_extension(ext) with self.assertRaises(validators.ValidationError): validators.blacklist_names( - csr=csr_mock, + csr=csr, domains=['.bad'], ) def test_blacklist_names_bad_cn(self): - cn_mock = mock.MagicMock() - cn_mock.get_value.return_value = 'blah.bad' - - csr_config = { - 'get_subject.return_value.get_entries_by_nid.return_value': - [cn_mock], - } - csr_mock = mock.MagicMock(**csr_config) + csr = x509_csr.X509Csr() + name = csr.get_subject() + name.add_name_entry(x509_name.OID_commonName, "blah.bad") with self.assertRaises(validators.ValidationError): validators.blacklist_names( - csr=csr_mock, + csr=csr, domains=['.bad'], ) def test_blacklist_names_mix(self): - ext1_mock = mock.MagicMock() - ext1_mock.get_value.return_value = 'DNS:blah.good' - ext1_mock.get_name.return_value = 'subjectAltName' - - ext2_mock = mock.MagicMock() - ext2_mock.get_value.return_value = 'DNS:blah.bad' - ext2_mock.get_name.return_value = 'subjectAltName' - - csr_mock = mock.MagicMock() - csr_mock.get_extensions.return_value = [ext1_mock, ext2_mock] + csr = x509_csr.X509Csr() + ext = x509_ext.X509ExtensionSubjectAltName() + ext.add_dns_id('blah.bad') + ext.add_dns_id('blah.good') + csr.add_extension(ext) with self.assertRaises(validators.ValidationError): validators.blacklist_names( - csr=csr_mock, + csr=csr, domains=['.bad'], ) - def test_blacklist_names_ignore_unknown(self): - # only validate the DNS type - other types may look like domains - # by accident - ext_mock = mock.MagicMock() - ext_mock.get_value.return_value = 'RANDOM_TYPE:random.bad' - ext_mock.get_name.return_value = 'subjectAltName' - - csr_mock = mock.MagicMock() - csr_mock.get_extensions.return_value = [ext_mock] - - self.assertEqual( - None, - validators.blacklist_names( - csr=csr_mock, - domains=['.bad'], - ) - ) - def test_blacklist_names_empty_list(self): # empty blacklist should pass everything through - ext_mock = mock.MagicMock() - ext_mock.get_value.return_value = 'DNS:some.name' - ext_mock.get_name.return_value = 'subjectAltName' - - csr_mock = mock.MagicMock() - csr_mock.get_extensions.return_value = [ext_mock] + csr = x509_csr.X509Csr() + ext = x509_ext.X509ExtensionSubjectAltName() + ext.add_dns_id('blah.good') + csr.add_extension(ext) self.assertEqual( None, validators.blacklist_names( - csr=csr_mock, + csr=csr, ) )