Integrate PyASN1 for certificate operations

Instead of relying on openssl code for certificate parsing, use the
ASN.1 representation directly. All previous features are supported. Not
all the extensions are full parsed yet, but the code doesn't require
them for now.

The code makes accessing and modifying the certificate structure simpler
and requires less error checking than the original version. The code
leaves few TODOs, but nothing that destroys previous behaviour.

It still uses the cryptography.io backend for loading keys and producing
signatures for the certificates.

Implements: blueprint direct-asn1
Change-Id: Ic555d3d056ca8da7016e2d8b434506cf214d06a1
This commit is contained in:
Stanisław Pitucha 2015-07-22 14:03:02 +10:00
parent 76c76e6ac5
commit ef390f5f54
19 changed files with 1340 additions and 1125 deletions

View File

@ -13,127 +13,108 @@
from __future__ import absolute_import from __future__ import absolute_import
from cryptography.hazmat.backends.openssl import backend import base64
import binascii
import io
from cryptography.hazmat import backends as cio_backends
from cryptography.hazmat.primitives.asymmetric import dsa
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives import hashes
from pyasn1.codec.der import decoder
from pyasn1.codec.der import encoder
from pyasn1.type import univ as asn1_univ
from pyasn1_modules import pem
from pyasn1_modules import rfc2459 # X509v3
from anchor.X509 import errors from anchor.X509 import errors
from anchor.X509 import message_digest from anchor.X509 import extension
from anchor.X509 import name from anchor.X509 import name
from anchor.X509 import utils from anchor.X509 import utils
SIGNING_ALGORITHMS = {
('RSA', 'MD5'): rfc2459.md5WithRSAEncryption,
('RSA', 'SHA1'): rfc2459.sha1WithRSAEncryption,
('RSA', 'SHA224'): asn1_univ.ObjectIdentifier('1.2.840.113549.1.1.14'),
('RSA', 'SHA256'): asn1_univ.ObjectIdentifier('1.2.840.113549.1.1.11'),
('RSA', 'SHA384'): asn1_univ.ObjectIdentifier('1.2.840.113549.1.1.12'),
('RSA', 'SHA512'): asn1_univ.ObjectIdentifier('1.2.840.113549.1.1.13'),
('DSA', 'SHA1'): rfc2459.id_dsa_with_sha1,
('DSA', 'SHA224'): asn1_univ.ObjectIdentifier('2.16.840.1.101.3.4.3.1'),
('DSA', 'SHA256'): asn1_univ.ObjectIdentifier('2.16.840.1.101.3.4.3.2'),
}
class X509CertificateError(errors.X509Error): class X509CertificateError(errors.X509Error):
"""Specific error for X509 certificate operations.""" """Specific error for X509 certificate operations."""
def __init__(self, what): pass
super(X509CertificateError, self).__init__(what)
class X509Extension(object):
"""An X509 V3 Certificate extension."""
def __init__(self, ext):
self._lib = backend._lib
self._ffi = backend._ffi
self._ext = ext
def __str__(self):
return "%s %s" % (self.get_name(), self.get_value())
def get_name(self):
"""Get the extension name as a python string."""
ext_obj = self._lib.X509_EXTENSION_get_object(self._ext)
ext_nid = self._lib.OBJ_obj2nid(ext_obj)
ext_name_str = self._lib.OBJ_nid2sn(ext_nid)
return self._ffi.string(ext_name_str).decode('ascii')
def get_value(self):
"""Get the extension value as a python string."""
bio = self._lib.BIO_new(self._lib.BIO_s_mem())
bio = self._ffi.gc(bio, self._lib.BIO_free)
self._lib.X509V3_EXT_print(bio, self._ext, 0, 0)
size = 1024
data = self._ffi.new("char[]", size)
self._lib.BIO_gets(bio, data, size)
return self._ffi.string(data).decode('ascii')
class X509Certificate(object): class X509Certificate(object):
"""X509 certificate class.""" """X509 certificate class."""
def __init__(self): def __init__(self, certificate=None):
self._lib = backend._lib if certificate is None:
self._ffi = backend._ffi self._cert = rfc2459.Certificate()
certObj = self._lib.X509_new() self._cert['tbsCertificate'] = rfc2459.TBSCertificate()
if certObj == self._ffi.NULL: else:
raise X509CertificateError("Could not create X509 certificate " self._cert = certificate
"object") # pragma: no cover
self._certObj = certObj @staticmethod
def from_open_file(f):
try:
der_content = pem.readPemFromFile(f)
certificate = decoder.decode(der_content,
asn1Spec=rfc2459.Certificate())[0]
return X509Certificate(certificate)
except Exception:
raise X509CertificateError("Could not read X509 certificate from "
"PEM data.")
def __del__(self): @staticmethod
if getattr(self, '_certObj', None): def from_buffer(data):
self._lib.X509_free(self._certObj)
def from_buffer(self, data):
"""Build this X509 object from a data buffer in memory. """Build this X509 object from a data buffer in memory.
:param data: A data buffer :param data: A data buffer
""" """
if type(data) != bytes: return X509Certificate.from_open_file(io.StringIO(data))
data = data.encode('ascii')
bio = backend._bytes_to_bio(data)
# NOTE(tkelsey): some versions of OpenSSL dont re-use the cert object @staticmethod
# properly, so free it and use the new one def from_file(path):
#
certObj = self._lib.PEM_read_bio_X509(bio[0],
self._ffi.NULL,
self._ffi.NULL,
self._ffi.NULL)
if certObj == self._ffi.NULL:
raise X509CertificateError("Could not read X509 certificate from "
"PEM data.")
self._lib.X509_free(self._certObj)
self._certObj = certObj
def from_file(self, path):
"""Build this X509 certificate object from a data file on disk. """Build this X509 certificate object from a data file on disk.
:param path: A data buffer :param path: A data buffer
""" """
data = None with open(path, 'r') as f:
with open(path, 'rb') as f: return X509Certificate.from_open_file(f)
data = f.read()
self.from_buffer(data)
def as_pem(self): def as_pem(self):
"""Serialise this X509 certificate object as PEM string.""" """Serialise this X509 certificate object as PEM string."""
raw_bio = self._lib.BIO_new(self._lib.BIO_s_mem()) header = '-----BEGIN CERTIFICATE-----'
bio = self._ffi.gc(raw_bio, self._lib.BIO_free) footer = '-----END CERTIFICATE-----'
ret = self._lib.PEM_write_bio_X509(bio, self._certObj) der_cert = encoder.encode(self._cert)
b64_encoder = (base64.encodestring if str is bytes else
if ret == 0: base64.encodebytes)
raise X509CertificateError("Could not write X509 certificate " b64_cert = b64_encoder(der_cert).decode('ascii')
"as PEM data.") # pragma: no cover return "%s\n%s%s\n" % (header, b64_cert, footer)
buf = self._ffi.new("char**")
pem_len = self._lib.BIO_get_mem_data(bio, buf)
pem = self._ffi.string(buf[0], pem_len)
return pem
def set_version(self, v): def set_version(self, v):
"""Set the version of this X509 certificate object. """Set the version of this X509 certificate object.
:param v: The version :param v: The version
""" """
ret = self._lib.X509_set_version(self._certObj, v) self._cert['tbsCertificate']['version'] = v
if ret == 0:
raise X509CertificateError("Could not set X509 certificate "
"version.") # pragma: no cover
def get_version(self): def get_version(self):
"""Get the version of this X509 certificate object.""" """Get the version of this X509 certificate object."""
return self._lib.X509_get_version(self._certObj) return self._cert['tbsCertificate']['version']
def get_validity(self):
if self._cert['tbsCertificate']['validity'] is None:
self._cert['tbsCertificate']['validity'] = None
return self._cert['tbsCertificate']['validity']
def set_not_before(self, t): def set_not_before(self, t):
"""Set the 'not before' date field. """Set the 'not before' date field.
@ -141,15 +122,13 @@ class X509Certificate(object):
:param t: time in seconds since the epoch :param t: time in seconds since the epoch
""" """
asn1_time = utils.timestamp_to_asn1_time(t) asn1_time = utils.timestamp_to_asn1_time(t)
ret = self._lib.X509_set_notBefore(self._certObj, asn1_time) validity = self.get_validity()
self._lib.ASN1_TIME_free(asn1_time) validity['notBefore'] = asn1_time
if ret == 0:
raise X509CertificateError("Could not set X509 certificate "
"not before time.") # pragma: no cover
def get_not_before(self): def get_not_before(self):
"""Get the 'not before' date field as seconds since the epoch.""" """Get the 'not before' date field as seconds since the epoch."""
not_before = self._lib.X509_get_notBefore(self._certObj) validity = self.get_validity()
not_before = validity['notBefore']
return utils.asn1_time_to_timestamp(not_before) return utils.asn1_time_to_timestamp(not_before)
def set_not_after(self, t): def set_not_after(self, t):
@ -158,37 +137,28 @@ class X509Certificate(object):
:param t: time in seconds since the epoch :param t: time in seconds since the epoch
""" """
asn1_time = utils.timestamp_to_asn1_time(t) asn1_time = utils.timestamp_to_asn1_time(t)
ret = self._lib.X509_set_notAfter(self._certObj, asn1_time) validity = self.get_validity()
self._lib.ASN1_TIME_free(asn1_time) validity['notAfter'] = asn1_time
if ret == 0:
raise X509CertificateError("Could not set X509 certificate "
"not after time.") # pragma: no cover
def get_not_after(self): def get_not_after(self):
"""Get the 'not after' date field as seconds since the epoch.""" """Get the 'not after' date field as seconds since the epoch."""
not_after = self._lib.X509_get_notAfter(self._certObj) validity = self.get_validity()
not_after = validity['notAfter']
return utils.asn1_time_to_timestamp(not_after) return utils.asn1_time_to_timestamp(not_after)
def set_pubkey(self, pkey): def set_pubkey(self, pkey):
"""Set the public key field. """Set the public key field.
:param pkey: The public key, an EVP_PKEY ssl type :param pkey: The public key, rfc2459.SubjectPublicKeyInfo description
""" """
ret = self._lib.X509_set_pubkey(self._certObj, pkey) self._cert['tbsCertificate']['subjectPublicKeyInfo'] = pkey
if ret == 0:
raise X509CertificateError("Could not set X509 certificate "
"pubkey.") # pragma: no cover
def get_subject(self): def get_subject(self):
"""Get the subject name field value. """Get the subject name field value.
:return: An X509Name object instance :return: An X509Name object instance
""" """
val = self._lib.X509_get_subject_name(self._certObj) val = self._cert['tbsCertificate']['subject'][0]
if val == self._ffi.NULL:
raise X509CertificateError("Could not get subject from X509 "
"certificate.") # pragma: no cover
return name.X509Name(val) return name.X509Name(val)
def set_subject(self, subject): def set_subject(self, subject):
@ -197,10 +167,9 @@ class X509Certificate(object):
:param subject: An X509Name object instance :param subject: An X509Name object instance
""" """
val = subject._name_obj val = subject._name_obj
ret = self._lib.X509_set_subject_name(self._certObj, val) if self._cert['tbsCertificate']['subject'] is None:
if ret == 0: self._cert['tbsCertificate']['subject'] = rfc2459.Name()
raise X509CertificateError("Could not set X509 certificate " self._cert['tbsCertificate']['subject'][0] = val
"subject.") # pragma: no cover
def set_issuer(self, issuer): def set_issuer(self, issuer):
"""Set the issuer name field value. """Set the issuer name field value.
@ -208,20 +177,16 @@ class X509Certificate(object):
:param issuer: An X509Name object instance :param issuer: An X509Name object instance
""" """
val = issuer._name_obj val = issuer._name_obj
ret = self._lib.X509_set_issuer_name(self._certObj, val) if self._cert['tbsCertificate']['issuer'] is None:
if ret == 0: self._cert['tbsCertificate']['issuer'] = rfc2459.Name()
raise X509CertificateError("Could not set X509 certificate " self._cert['tbsCertificate']['issuer'][0] = val
"issuer.") # pragma: no cover
def get_issuer(self): def get_issuer(self):
"""Get the issuer name field value. """Get the issuer name field value.
:return: An X509Name object instance :return: An X509Name object instance
""" """
val = self._lib.X509_get_issuer_name(self._certObj) val = self._cert['tbsCertificate']['issuer'][0]
if val == self._ffi.NULL:
raise X509CertificateError("Could not get subject from X509 "
"certificate.") # pragma: no cover
return name.X509Name(val) return name.X509Name(val)
def set_serial_number(self, serial): def set_serial_number(self, serial):
@ -232,14 +197,18 @@ class X509Certificate(object):
:param serial: The serial number, 32 bit integer :param serial: The serial number, 32 bit integer
""" """
asn1_int = self._lib.ASN1_INTEGER_new() self._cert['tbsCertificate']['serialNumber'] = serial
ret = self._lib.ASN1_INTEGER_set(asn1_int, serial)
if ret != 0: def _get_extensions(self):
ret = self._lib.X509_set_serialNumber(self._certObj, asn1_int) if self._cert['tbsCertificate']['extensions'] is None:
self._lib.ASN1_INTEGER_free(asn1_int) # this actually initialises the extensions tag rather than
if ret == 0: # assign None
raise X509CertificateError("Could not set X509 certificate " self._cert['tbsCertificate']['extensions'] = None
"serial number.") # pragma: no cover return self._cert['tbsCertificate']['extensions']
def get_extensions(self):
extensions = self._get_extensions()
return [extension.construct_extension(e) for e in extensions]
def add_extension(self, ext, index): def add_extension(self, ext, index):
"""Add an X509 V3 Certificate extension. """Add an X509 V3 Certificate extension.
@ -247,10 +216,11 @@ class X509Certificate(object):
:param ext: An X509Extension instance :param ext: An X509Extension instance
:param index: The index of the extension :param index: The index of the extension
""" """
ret = self._lib.X509_add_ext(self._certObj, ext._ext, index) if not isinstance(ext, extension.X509Extension):
if ret == 0: raise errors.X509Error("ext needs to be a pyasn1 extension")
raise X509CertificateError("Could not add X509 certificate "
"extension.") # pragma: no cover extensions = self._get_extensions()
extensions[index] = ext.as_asn1()
def sign(self, key, md='sha1'): def sign(self, key, md='sha1'):
"""Sign the X509 certificate with a key using a message digest algorithm """Sign the X509 certificate with a key using a message digest algorithm
@ -262,28 +232,44 @@ class X509Certificate(object):
- sha1 - sha1
- sha256 - sha256
""" """
mda = getattr(self._lib, "EVP_%s" % md, None) md = md.upper()
if mda is None:
msg = 'X509 signing error: Unknown algorithm {a}'.format(a=md) if isinstance(key, rsa.RSAPrivateKey):
raise X509CertificateError(msg) encryption = 'RSA'
ret = self._lib.X509_sign(self._certObj, key, mda()) elif isinstance(key, dsa.DSAPrivateKey):
if ret == 0: encryption = 'DSA'
raise X509CertificateError("X509 signing error: Could not sign " else:
" certificate.") # pragma: no cover raise errors.X509Error("Unknown key type: %s" % (key.__class__,))
hash_class = utils.get_hash_class(md)
signature_type = SIGNING_ALGORITHMS.get((encryption, md))
if signature_type is None:
raise errors.X509Error(
"Unknown encryption/hash combination %s/%s" % (encryption, md))
algo_id = rfc2459.AlgorithmIdentifier()
algo_id['algorithm'] = signature_type
if encryption == 'RSA':
algo_id['parameters'] = encoder.encode(asn1_univ.Null())
elif encryption == 'DSA':
pass # parameters should be omitted, see RFC3279
self._cert['tbsCertificate']['signature'] = algo_id
to_sign = encoder.encode(self._cert['tbsCertificate'])
if encryption == 'RSA':
signer = key.signer(padding.PKCS1v15(), hash_class())
elif encryption == 'DSA':
signer = key.signer(hash_class())
signer.update(to_sign)
signature = signer.finalize()
self._cert['signatureValue'] = "'%s'B" % (
utils.bytes_to_bin(signature),)
self._cert['signatureAlgorithm'] = algo_id
def as_der(self): def as_der(self):
"""Return this X509 certificate as DER encoded data.""" """Return this X509 certificate as DER encoded data."""
buf = None return encoder.encode(self._cert)
num = self._lib.i2d_X509(self._certObj, self._ffi.NULL)
if num != 0:
buf = self._ffi.new("unsigned char[]", num + 1)
buf_ptr = self._ffi.new("unsigned char**")
buf_ptr[0] = buf
num = self._lib.i2d_X509(self._certObj, buf_ptr)
else:
raise X509CertificateError("Could not encode X509 certificate "
"as DER.") # pragma: no cover
return buf
def get_fingerprint(self, md='md5'): def get_fingerprint(self, md='md5'):
"""Get the fingerprint of this X509 certificate. """Get the fingerprint of this X509 certificate.
@ -291,7 +277,11 @@ class X509Certificate(object):
:param md: The message digest algorthim used to compute the fingerprint :param md: The message digest algorthim used to compute the fingerprint
:return: The fingerprint encoded as a hex string :return: The fingerprint encoded as a hex string
""" """
der = self.as_der() hash_class = utils.get_hash_class(md)
md = message_digest.MessageDigest(md) if hash_class is None:
md.update(der) raise errors.X509Error(
return md.final() "Unknown hash %s" % (md,))
hasher = hashes.Hash(hash_class(),
backend=cio_backends.default_backend())
hasher.update(self.as_der())
return binascii.hexlify(hasher.finalize()).upper().decode('ascii')

318
anchor/X509/extension.py Normal file
View File

@ -0,0 +1,318 @@
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import
import functools
import netaddr
from pyasn1.codec.der import decoder
from pyasn1.codec.der import encoder
from pyasn1.type import constraint as asn1_constraint
from pyasn1.type import namedtype as asn1_namedtype
from pyasn1.type import univ as asn1_univ
from pyasn1_modules import rfc2459 # X509v3
from anchor.X509 import errors
from anchor.X509 import utils
EXTENSION_NAMES = {
rfc2459.id_ce_policyConstraints: 'policyConstraints',
rfc2459.id_ce_basicConstraints: 'basicConstraints',
rfc2459.id_ce_subjectDirectoryAttributes: 'subjectDirectoryAttributes',
rfc2459.id_ce_deltaCRLIndicator: 'deltaCRLIndicator',
rfc2459.id_ce_cRLDistributionPoints: 'cRLDistributionPoints',
rfc2459.id_ce_issuingDistributionPoint: 'issuingDistributionPoint',
rfc2459.id_ce_nameConstraints: 'nameConstraints',
rfc2459.id_ce_certificatePolicies: 'certificatePolicies',
rfc2459.id_ce_policyMappings: 'policyMappings',
rfc2459.id_ce_privateKeyUsagePeriod: 'privateKeyUsagePeriod',
rfc2459.id_ce_keyUsage: 'keyUsage',
rfc2459.id_ce_authorityKeyIdentifier: 'authorityKeyIdentifier',
rfc2459.id_ce_subjectKeyIdentifier: 'subjectKeyIdentifier',
rfc2459.id_ce_certificateIssuer: 'certificateIssuer',
rfc2459.id_ce_subjectAltName: 'subjectAltName',
rfc2459.id_ce_issuerAltName: 'issuerAltName',
}
LONG_KEY_USAGE_NAMES = {
"Digital Signature": "digitalSignature",
"Non Repudiation": "nonRepudiation",
"Key Encipherment": "keyEncipherment",
"Data Encipherment": "dataEncipherment",
"Key Agreement": "keyAgreement",
"Certificate Sign": "keyCertSign",
"CRL Sign": "cRLSign",
"Encipher Only": "encipherOnly",
"Decipher Only": "decipherOnly",
}
def uses_ext_value(f):
"""Wrapper allowing reading of extension value.
Because the value is normally saved in a (double) serialised way, it's
not easily accessible to the member methods. This is made easier by
unpacking the extension value into an extra argument.
"""
@functools.wraps(f)
def ext_value_filled(self, *args, **kwargs):
kwargs['ext_value'] = self._get_value()
return f(self, *args, **kwargs)
return ext_value_filled
def modifies_ext_value(f):
"""Wrapper allowing modification of extension value.
Because the value is normally saved in a (double) serialised way, it's
not easily accessible to the member methods. This is made easier by
unpacking the extension value into an extra argument.
New value needs to be returned from the method.
"""
@functools.wraps(f)
def ext_value_filled(self, *args, **kwargs):
value = self._get_value()
kwargs['ext_value'] = value
# since some elements like NamedValue are pure value types, there is
# no interface to modify them and new versions have to be returned
value = f(self, *args, **kwargs)
self._set_value(value)
return ext_value_filled
class BasicConstraints(asn1_univ.Sequence):
"""Custom BasicConstraint implementation until pyasn1_modules is fixes."""
componentType = asn1_namedtype.NamedTypes(
asn1_namedtype.DefaultedNamedType('cA', asn1_univ.Boolean(False)),
asn1_namedtype.OptionalNamedType(
'pathLenConstraint',
asn1_univ.Integer().subtype(
subtypeSpec=asn1_constraint.ValueRangeConstraint(0, 64)))
)
class X509Extension(object):
"""Abstraction for the pyasn1 Extension structures.
The object should normally be constructed using `construct_extension`,
which will choose the right extension type based on the id.
Each extension has an immutable oid and a spec of the internal value
representation.
Unknown extension types can be still represented by the
X509Extension object and copied/serialised without understanding the
value details. The value will not be displayed properly in the logs
in the case.
"""
_oid = None
spec = None
"""An X509 V3 Certificate extension."""
def __init__(self, ext=None):
if ext is None:
if self.spec is None:
raise errors.X509Error("cannot create generic extension")
self._ext = rfc2459.Extension()
self._ext['extnID'] = self._oid
self._set_value(self._get_default_value())
else:
if not isinstance(ext, rfc2459.Extension):
raise errors.X509Error("extension has incorrect type")
self._ext = ext
@classmethod
def _get_default_value(cls):
# if there are any non-optional fields, this needs to be defined in
# the class
return cls.spec()
def __str__(self):
return "%s: %s" % (self.get_name(), self.get_value_as_str())
def get_value_as_str(self):
return "<unknown>"
def get_oid(self):
return self._ext['extnID']
def get_name(self):
"""Get the extension name as a python string."""
oid = self.get_oid()
return EXTENSION_NAMES.get(oid, oid)
def get_critical(self):
return self._ext['critical']
def set_critical(self, critical):
self._ext['critical'] = critical
def _get_value(self):
value_der = decoder.decode(self._ext['extnValue'])[0]
return decoder.decode(value_der, asn1Spec=self.spec())[0]
def _set_value(self, value):
if not isinstance(value, self.spec):
raise errors.X509Error("extension value has incorrect type")
self._ext['extnValue'] = encoder.encode(rfc2459.univ.OctetString(
encoder.encode(value)))
def as_der(self):
return encoder.encode(self._ext)
def as_asn1(self):
return self._ext
class X509ExtensionBasicConstraints(X509Extension):
spec = BasicConstraints
_oid = rfc2459.id_ce_basicConstraints
@uses_ext_value
def get_ca(self, ext_value=None):
return bool(ext_value['cA'])
@modifies_ext_value
def set_ca(self, ca, ext_value=None):
ext_value['cA'] = ca
return ext_value
@uses_ext_value
def get_path_len_constraint(self, ext_value=None):
return ext_value['pathLenConstraint']
@modifies_ext_value
def set_path_len_constraint(self, length, ext_value=None):
ext_value['pathLenConstraint'] = length
return ext_value
def __str__(self):
return "basicConstraints: CA: %s, pathLen: %s" % (
str(self.get_ca()).upper(), self.get_path_len_constraint())
class X509ExtensionKeyUsage(X509Extension):
spec = rfc2459.KeyUsage
_oid = rfc2459.id_ce_keyUsage
fields = dict(spec.namedValues.namedValues)
inv_fields = dict((v, k) for k, v in spec.namedValues.namedValues)
@classmethod
def _get_default_value(cls):
# if there are any non-optional fields, this needs to be defined in
# the class
return cls.spec("''B")
@uses_ext_value
def get_usage(self, usage, ext_value=None):
usage = LONG_KEY_USAGE_NAMES.get(usage, usage)
pos = self.fields[usage]
if pos >= len(ext_value):
return False
return bool(ext_value[pos])
@uses_ext_value
def get_all_usages(self, ext_value=None):
return [self.inv_fields[i] for i, enabled in enumerate(ext_value)
if enabled]
@modifies_ext_value
def set_usage(self, usage, state, ext_value=None):
usage = LONG_KEY_USAGE_NAMES.get(usage, usage)
pos = self.fields[usage]
values = [x for x in ext_value]
if state:
while pos >= len(values):
values.append(0)
values[pos] = 1
else:
if pos < len(values):
values[pos] = 0
bits = ''.join(str(x) for x in values)
return self.spec("'%s'B" % bits)
def __str__(self):
return "keyUsage: " + ", ".join(self.get_all_usages())
class X509ExtensionSubjectAltName(X509Extension):
spec = rfc2459.SubjectAltName
_oid = rfc2459.id_ce_subjectAltName
@uses_ext_value
def get_dns_ids(self, ext_value=None):
dns_ids = []
for name in ext_value:
if name.getName() != 'dNSName':
continue
component = name.getComponent()
dns_id = component.asOctets().decode(component.encoding)
dns_ids.append(dns_id)
return dns_ids
@uses_ext_value
def get_ips(self, ext_value=None):
ips = []
for name in ext_value:
if name.getName() != 'iPAddress':
continue
ips.append(utils.asn1_to_netaddr(name.getComponent()))
return ips
@modifies_ext_value
def add_dns_id(self, dns_id, ext_value=None):
# TODO(stan) validate dns_id
new_pos = len(ext_value)
ext_value[new_pos] = None
ext_value[new_pos]['dNSName'] = dns_id
return ext_value
@modifies_ext_value
def add_ip(self, ip, ext_value=None):
if not isinstance(ip, netaddr.IPAddress):
raise errors.X509Error("not a real ip address provided")
new_pos = len(ext_value)
ext_value[new_pos] = None
ext_value[new_pos]['iPAddress'] = utils.netaddr_to_asn1(ip)
return ext_value
@uses_ext_value
def __str__(self, ext_value=None):
entries = ["DNS:%s" % (x,) for x in self.get_dns_ids()]
entries += ["IP:%s" % (x,) for x in self.get_ips()]
return "subjectAltName: " + ", ".join(entries)
EXTENSION_CLASSES = {
rfc2459.id_ce_basicConstraints: X509ExtensionBasicConstraints,
rfc2459.id_ce_keyUsage: X509ExtensionKeyUsage,
rfc2459.id_ce_subjectAltName: X509ExtensionSubjectAltName,
}
def construct_extension(ext):
"""Construct an extension object of the right type.
While X509Extension can provide basic access to the extension elements,
it cannot parse details of extensions. This function detects which type
should be used based on the extension id.
If the type is unknown, generic X509Extension is used instead.
"""
if not isinstance(ext, rfc2459.Extension):
raise errors.X509Error("extension has incorrect type")
ext_class = EXTENSION_CLASSES.get(ext['extnID'], X509Extension)
return ext_class(ext)

View File

@ -1,93 +0,0 @@
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import
from cryptography.hazmat.backends.openssl import backend
import binascii
class MessageDigestError(Exception):
def __init__(self, what):
super(MessageDigestError, self).__init__(what)
class MessageDigest(object):
"""Compute a message digest from input data."""
@staticmethod
def getValidAlgorithms():
"""Get a list of available valid hash algorithms."""
algs = [
"md5",
"ripemd160",
"sha224",
"sha256",
"sha384",
"sha512"
]
ret = []
for alg in algs:
if getattr(backend._lib, "EVP_%s" % alg, None) is not None:
ret.append(alg)
return ret
def __init__(self, algo):
self._lib = backend._lib
self._ffi = backend._ffi
md = getattr(self._lib, "EVP_%s" % algo, None)
if md is None:
msg = 'MessageDigest error: unknown algorithm {a}'.format(a=algo)
raise MessageDigestError(msg)
ret = 0
ctx = self._lib.EVP_MD_CTX_create()
if ctx != self._ffi.NULL:
self.ctx = ctx
self.mda = md()
ret = self._lib.EVP_DigestInit_ex(self.ctx,
self.mda,
self._ffi.NULL)
if ret == 0:
raise MessageDigestError(
"Could not setup message digest context.") # pragma: no cover
def __del__(self):
if getattr(self, 'ctx', None):
self._lib.EVP_MD_CTX_cleanup(self.ctx)
self._lib.EVP_MD_CTX_destroy(self.ctx)
def update(self, data):
"""Add more data to the digest."""
ret = self._lib.EVP_DigestUpdate(self.ctx, data, len(data))
if ret == 0:
raise MessageDigestError(
"Failed to update message digest data.") # pragma: no cover
def final(self):
"""get the final resulting digest value.
Note that you should not call update() with additional data after using
final.
"""
sz = self._lib.EVP_MD_size(self.mda)
data = self._ffi.new("char[]", sz)
ret = self._lib.EVP_DigestFinal_ex(self.ctx, data, self._ffi.NULL)
if ret == 0:
raise MessageDigestError(
"Failed to get message digest.") # pragma: no cover
digest = self._ffi.string(data)
return binascii.hexlify(digest).decode('ascii').upper()

View File

@ -13,21 +13,64 @@
from __future__ import absolute_import from __future__ import absolute_import
from cryptography.hazmat.backends.openssl import backend from pyasn1.codec.der import decoder
from pyasn1.codec.der import encoder
from pyasn1.type import error as asn1_error
from pyasn1.type import univ as asn1_univ
from pyasn1_modules import rfc2459
from anchor.X509 import errors from anchor.X509 import errors
from anchor.X509 import utils
OID_commonName = rfc2459.id_at_commonName
OID_localityName = rfc2459.id_at_localityName
OID_stateOrProvinceName = rfc2459.id_at_stateOrProvinceName
OID_organizationName = rfc2459.id_at_organizationName
OID_organizationalUnitName = rfc2459.id_at_organizationalUnitName
OID_countryName = rfc2459.id_at_countryName
OID_pkcs9_emailAddress = rfc2459.emailAddress
OID_surname = rfc2459.id_at_sutname
OID_givenName = rfc2459.id_at_givenName
NID_countryName = backend._lib.NID_countryName name_oids = {
NID_stateOrProvinceName = backend._lib.NID_stateOrProvinceName rfc2459.id_at_name: rfc2459.X520name,
NID_localityName = backend._lib.NID_localityName rfc2459.id_at_sutname: rfc2459.X520name,
NID_organizationName = backend._lib.NID_organizationName rfc2459.id_at_givenName: rfc2459.X520name,
NID_organizationalUnitName = backend._lib.NID_organizationalUnitName rfc2459.id_at_initials: rfc2459.X520name,
NID_commonName = backend._lib.NID_commonName rfc2459.id_at_generationQualifier: rfc2459.X520name,
NID_pkcs9_emailAddress = backend._lib.NID_pkcs9_emailAddress rfc2459.id_at_commonName: rfc2459.X520CommonName,
NID_surname = backend._lib.NID_surname rfc2459.id_at_localityName: rfc2459.X520LocalityName,
NID_givenName = backend._lib.NID_givenName rfc2459.id_at_stateOrProvinceName: rfc2459.X520StateOrProvinceName,
rfc2459.id_at_organizationName: rfc2459.X520OrganizationName,
rfc2459.id_at_organizationalUnitName: rfc2459.X520OrganizationalUnitName,
rfc2459.id_at_title: rfc2459.X520Title,
rfc2459.id_at_dnQualifier: rfc2459.X520dnQualifier,
rfc2459.id_at_countryName: rfc2459.X520countryName,
rfc2459.emailAddress: rfc2459.Pkcs9email,
}
code_names = {
rfc2459.id_at_commonName: "CN",
rfc2459.id_at_localityName: "L",
rfc2459.id_at_stateOrProvinceName: "ST",
rfc2459.id_at_organizationName: "O",
rfc2459.id_at_organizationalUnitName: "OU",
rfc2459.id_at_countryName: "C",
rfc2459.id_at_givenName: "GN",
rfc2459.id_at_sutname: "SN",
rfc2459.emailAddress: "emailAddress",
}
short_names = {
rfc2459.id_at_commonName: "commonName",
rfc2459.id_at_localityName: "localityName",
rfc2459.id_at_stateOrProvinceName: "stateOrProvinceName",
rfc2459.id_at_organizationName: "organizationName",
rfc2459.id_at_organizationalUnitName: "organizationalUnitName",
rfc2459.id_at_countryName: "countryName",
rfc2459.id_at_givenName: "givenName",
rfc2459.id_at_sutname: "surname",
rfc2459.emailAddress: "emailAddress",
}
class X509Name(object): class X509Name(object):
@ -35,104 +78,90 @@ class X509Name(object):
class Entry(): class Entry():
"""An X509 Name sub-entry object.""" """An X509 Name sub-entry object."""
def __init__(self, obj, parent): def __init__(self, obj):
self._parent = parent self._obj = obj
self._lib = backend._lib
self._ffi = backend._ffi
self._entry = obj
def __str__(self): def __str__(self):
return "%s: %s" % (self.get_name(), self.get_value()) return "%s: %s" % (self.get_name(), self.get_value())
def get_oid(self):
return self._obj[0]['type']
def get_name(self): def get_name(self):
"""Get the name of this entry. """Get the name of this entry.
:return: entry name as a python string :return: entry name as a python string
""" """
asn1_obj = self._lib.X509_NAME_ENTRY_get_object(self._entry) oid = self.get_oid()
buf = self._ffi.new('char[]', 1024) return short_names.get(oid, str(oid))
ret = self._lib.OBJ_obj2txt(buf, 1024, asn1_obj, 0)
if ret == 0: def get_code(self):
raise errors.X509Error("Could not convert ASN1_OBJECT to " """Get the name of this entry.
"string.") # pragma: no cover
return self._ffi.string(buf).decode('ascii') :return: entry name as a python string
"""
oid = self.get_oid()
return code_names.get(oid, str(oid))
def get_value(self): def get_value(self):
"""Get the value of this entry. """Get the value of this entry.
:return: entry value as a python string :return: entry value as a python string
""" """
val = self._lib.X509_NAME_ENTRY_get_data(self._entry) value = self._obj[0]['value']
return utils.asn1_string_to_utf8(val) der = value.asOctets()
name_spec = name_oids[self.get_oid()]()
value = decoder.decode(der, asn1Spec=name_spec)[0]
if hasattr(value, 'getComponent'):
value = value.getComponent()
return value.asOctets().decode(value.encoding)
def __init__(self, name_obj=None): def __init__(self, name_obj=None):
self._lib = backend._lib
self._ffi = backend._ffi
if name_obj is not None: if name_obj is not None:
self._name_obj = self._lib.X509_NAME_dup(name_obj) if not isinstance(name_obj, rfc2459.RDNSequence):
if self._name_obj == self._ffi.NULL: raise TypeError("name is not an RDNSequence")
raise errors.X509Error("Failed to copy X509_NAME " # TODO(stan): actual copy
"object.") # pragma: no cover self._name_obj = name_obj
else: else:
self._name_obj = self._lib.X509_NAME_new() self._name_obj = rfc2459.RDNSequence()
if self._name_obj == self._ffi.NULL:
raise errors.X509Error("Failed to create "
"X509_NAME object.") # pragma: no cover
def __del__(self):
self._lib.X509_NAME_free(self._name_obj)
def __str__(self): def __str__(self):
# NOTE(tkelsey): we need to pass in a max size, so why not 1024 return '/' + '/'.join("%s=%s" % (e.get_code(), e.get_value())
val = self._lib.X509_NAME_oneline(self._name_obj, self._ffi.NULL, 1024) for e in self)
if val == self._ffi.NULL:
raise errors.X509Error("Could not convert"
" X509_NAME to string.") # pragma: no cover
val = self._ffi.gc(val, self._lib.OPENSSL_free)
return self._ffi.string(val).decode('ascii')
def __len__(self): def __len__(self):
return self._lib.X509_NAME_entry_count(self._name_obj) return len(self._name_obj)
def __getitem__(self, idx): def __getitem__(self, idx):
if not (0 <= idx < self.entry_count()): return X509Name.Entry(self._name_obj[idx])
raise IndexError("index out of range")
ent = self._lib.X509_NAME_get_entry(self._name_obj, idx)
return X509Name.Entry(ent, self)
def __iter__(self): def __iter__(self):
for i in range(self.entry_count()): for i in range(len(self)):
yield self[i] yield self[i]
def add_name_entry(self, nid, text): def add_name_entry(self, oid, text):
"""Add a name entry by its NID name.""" if not isinstance(oid, asn1_univ.ObjectIdentifier):
ret = self._lib.X509_NAME_add_entry_by_NID( raise errors.X509Error("oid '%s' is not valid" % (oid,))
self._name_obj, nid, entry = rfc2459.RelativeDistinguishedName()
self._lib.MBSTRING_UTF8, entry[0] = rfc2459.AttributeTypeAndValue()
text.encode('utf8'), -1, -1, 0) entry[0]['type'] = oid
name_type = name_oids[oid]
try:
if name_type in (rfc2459.X520countryName, rfc2459.Pkcs9email):
val = name_type(text)
else:
val = name_type()
val['utf8String'] = text
except asn1_error.ValueConstraintError:
raise errors.X509Error("Name '%s' is not valid" % text)
entry[0]['value'] = rfc2459.AttributeValue(encoder.encode(val))
self._name_obj[len(self)] = entry
if ret != 1: def get_entries_by_oid(self, oid):
raise errors.X509Error("Failed to add name entry: '%s' '%s'" % (
nid, text))
def entry_count(self):
"""Get the number of entries in the name object."""
return self._lib.X509_NAME_entry_count(self._name_obj)
def get_entries_by_nid(self, nid):
"""Get a name entry corresponding to an NID name. """Get a name entry corresponding to an NID name.
:param nid: an NID for the new name entry :param nid: an NID for the new name entry
:return: An X509Name.Entry object :return: An X509Name.Entry object
""" """
out = [] return [entry for entry in self if entry.get_oid() == oid]
idx = self._lib.X509_NAME_get_index_by_NID(self._name_obj, nid, -1)
while idx != -1:
val = self._lib.X509_NAME_get_entry(self._name_obj, idx)
if val != self._ffi.NULL:
out.append(X509Name.Entry(val, self))
idx = self._lib.X509_NAME_get_index_by_NID(self._name_obj,
nid, idx)
return out

View File

@ -13,13 +13,23 @@
from __future__ import absolute_import from __future__ import absolute_import
from cryptography.hazmat.backends.openssl import backend import io
from pyasn1.codec.der import decoder
from pyasn1.codec.der import encoder
from pyasn1.type import univ as asn1_univ
from pyasn1_modules import pem
from pyasn1_modules import rfc2314 # PKCS#10 / CSR
from pyasn1_modules import rfc2459 # X509
from anchor.X509 import certificate
from anchor.X509 import errors from anchor.X509 import errors
from anchor.X509 import extension
from anchor.X509 import name from anchor.X509 import name
OID_extensionRequest = asn1_univ.ObjectIdentifier('1.2.840.113549.1.9.14')
class X509CsrError(errors.X509Error): class X509CsrError(errors.X509Error):
def __init__(self, what): def __init__(self, what):
super(X509CsrError, self).__init__(what) super(X509CsrError, self).__init__(what)
@ -27,84 +37,108 @@ class X509CsrError(errors.X509Error):
class X509Csr(object): class X509Csr(object):
"""An X509 Certificate Signing Request.""" """An X509 Certificate Signing Request."""
def __init__(self): def __init__(self, csr=None):
self._lib = backend._lib if csr is None:
self._ffi = backend._ffi self._csr = rfc2314.CertificationRequest()
csrObj = self._lib.X509_REQ_new() else:
if csrObj == self._ffi.NULL: self._csr = csr
raise X509CsrError(
"Could not create X509 CSR Object.") # pragma: no cover
self._csrObj = csrObj @staticmethod
def from_open_file(f):
try:
der_content = pem.readPemFromFile(
f, startMarker='-----BEGIN CERTIFICATE REQUEST-----',
endMarker='-----END CERTIFICATE REQUEST-----')
csr = decoder.decode(der_content,
asn1Spec=rfc2314.CertificationRequest())[0]
return X509Csr(csr)
except Exception:
raise X509CsrError("Could not read X509 certificate from "
"PEM data.")
def __del__(self): @staticmethod
if getattr(self, '_csrObj', None): def from_buffer(data):
self._lib.X509_REQ_free(self._csrObj)
def from_buffer(self, data, password=None):
"""Create this CSR from a buffer """Create this CSR from a buffer
:param data: The data buffer :param data: The data buffer
:param password: decryption password, if needed
""" """
if type(data) != bytes: return X509Csr.from_open_file(io.StringIO(data))
data = data.encode('ascii')
bio = backend._bytes_to_bio(data)
ptr = self._ffi.new("X509_REQ **")
ptr[0] = self._csrObj
ret = self._lib.PEM_read_bio_X509_REQ(bio[0], ptr,
self._ffi.NULL,
self._ffi.NULL)
if ret == self._ffi.NULL:
raise X509CsrError("Could not read X509 CSR from PEM data.")
def from_file(self, path, password=None): @staticmethod
def from_file(path):
"""Create this CSR from a file on disk """Create this CSR from a file on disk
:param path: Path to the file on disk :param path: Path to the file on disk
:param password: decryption password, if needed
""" """
data = None with open(path, 'r') as f:
with open(path, 'rb') as f: return X509Csr.from_open_file(f)
data = f.read()
self.from_buffer(data, password)
def get_pubkey(self): def get_pubkey(self):
"""Get the public key from the CSR """Get the public key from the CSR
:return: an OpenSSL EVP_PKEY object :return: an OpenSSL EVP_PKEY object
""" """
pkey = self._lib.X509_REQ_get_pubkey(self._csrObj) return self._csr['certificationRequestInfo']['subjectPublicKeyInfo']
if pkey == self._ffi.NULL:
raise X509CsrError(
"Could not get pubkey from X509 CSR.") # pragma: no cover
return pkey def get_request_info(self):
if self._csr['certificationRequestInfo'] is None:
self._csr['certificationRequestInfo'] = None
return self._csr['certificationRequestInfo']
def get_subject(self): def get_subject(self):
"""Get the subject name field from the CSR """Get the subject name field from the CSR
:return: an X509Name object :return: an X509Name object
""" """
subs = self._lib.X509_REQ_get_subject_name(self._csrObj) ri = self.get_request_info()
if subs == self._ffi.NULL: if ri['subject'] is None:
raise X509CsrError( ri['subject'] = None
"Could not get subject from X509 CSR.") # pragma: no cover # setup first RDN sequence
ri['subject'][0] = None
return name.X509Name(subs) subject = ri['subject'][0]
return name.X509Name(subject)
def get_extensions(self): def get_attributes(self):
ri = self.get_request_info()
if ri['attributes'] is None:
ri['attributes'] = None
return ri['attributes']
def get_extensions(self, ext_type=None):
"""Get the list of all X509 V3 Extensions on this CSR """Get the list of all X509 V3 Extensions on this CSR
:return: a list of X509Extension objects :return: a list of X509Extension objects
""" """
# TODO(tkelsey): I assume the ext list copies data and this is safe ext_attrs = [a for a in self.get_attributes()
# TODO(tkelsey): Error checking needed here if a['type'] == OID_extensionRequest]
ret = [] if len(ext_attrs) == 0:
exts = self._lib.X509_REQ_get_extensions(self._csrObj) return []
num = self._lib.sk_X509_EXTENSION_num(exts) else:
for i in range(0, num): exts_der = ext_attrs[0]['vals'][0].asOctets()
ext = self._lib.sk_X509_EXTENSION_value(exts, i) exts = decoder.decode(exts_der, asn1Spec=rfc2459.Extensions())[0]
ret.append(certificate.X509Extension(ext)) return [extension.construct_extension(e) for e in exts
self._lib.sk_X509_EXTENSION_free(exts) if ext_type is None or e['extnID'] == ext_type._oid]
return ret
def add_extension(self, ext):
if not isinstance(ext, extension.X509Extension):
raise errors.X509Error("ext is not an anchor X509Extension")
attributes = self.get_attributes()
ext_attrs = [a for a in attributes
if a['type'] == OID_extensionRequest]
if not ext_attrs:
new_attr_index = len(attributes)
attributes[new_attr_index] = None
ext_attr = attributes[new_attr_index]
ext_attr['type'] = OID_extensionRequest
ext_attr['vals'] = None
exts = rfc2459.Extensions()
else:
ext_attr = ext_attrs[0]
exts = decoder.decode(ext_attr['vals'][0].asOctets(),
asn1Spec=rfc2459.Extensions())[0]
new_ext_index = len(exts)
exts[new_ext_index] = ext._ext
ext_attr['vals'][0] = encoder.encode(exts)

View File

@ -15,38 +15,18 @@ from __future__ import absolute_import
import calendar import calendar
import datetime import datetime
import struct
from cryptography.hazmat.backends.openssl import backend from cryptography.hazmat import backends
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives import serialization
import netaddr
from pyasn1.type import useful as asn1_useful
from pyasn1_modules import rfc2459
from anchor.X509 import errors from anchor.X509 import errors
def load_pem_private_key(key_data, passwd=None):
"""Load and return an OpenSSL EVP_PKEY public key object from a data buffer
:param key_data: The data buffer
:param passwd: Decryption password if neded (not used for now)
:return: an OpenSSL EVP_PKEY public key object
"""
# TODO(tkelsey): look at using backend.read_private_key
#
if type(key_data) != bytes:
key_data = key_data.encode('ascii')
lib = backend._lib
ffi = backend._ffi
data = backend._bytes_to_bio(key_data)
evp_pkey = lib.EVP_PKEY_new()
evp_pkey_ptr = ffi.new("EVP_PKEY**")
evp_pkey_ptr[0] = evp_pkey
evp_pkey = lib.PEM_read_bio_PrivateKey(data[0], evp_pkey_ptr,
ffi.NULL, ffi.NULL)
evp_pkey = ffi.gc(evp_pkey, lib.EVP_PKEY_free)
return evp_pkey
def create_timezone(minute_offset): def create_timezone(minute_offset):
"""Create a new timezone with a specified offset. """Create a new timezone with a specified offset.
@ -80,18 +60,17 @@ def asn1_time_to_timestamp(t):
:param t: ASN1_TIME to convert :param t: ASN1_TIME to convert
""" """
component = t.getComponent()
gen_time = backend._lib.ASN1_TIME_to_generalizedtime(t, backend._ffi.NULL) timestring = component.asOctets().decode(component.encoding)
if gen_time == backend._ffi.NULL: if isinstance(component, asn1_useful.UTCTime):
raise errors.ASN1TimeError("time conversion failure") if int(timestring[0]) >= 5:
timestring = "19" + timestring
try: else:
return asn1_generalizedtime_to_timestamp(gen_time) timestring = "20" + timestring
finally: return asn1_timestring_to_timestamp(timestring)
backend._lib.ASN1_GENERALIZEDTIME_free(gen_time)
def asn1_generalizedtime_to_timestamp(gt): def asn1_timestring_to_timestamp(timestring):
"""Convert from ASN1_GENERALIZEDTIME to UTC-based timestamp. """Convert from ASN1_GENERALIZEDTIME to UTC-based timestamp.
:param gt: ASN1_GENERALIZEDTIME to convert :param gt: ASN1_GENERALIZEDTIME to convert
@ -99,11 +78,8 @@ def asn1_generalizedtime_to_timestamp(gt):
# ASN1_GENERALIZEDTIME is actually a string in known formats, # ASN1_GENERALIZEDTIME is actually a string in known formats,
# so the conversion can be done in this code # so the conversion can be done in this code
string_time = backend._ffi.cast("ASN1_STRING*", gt) before_tz = timestring[:14]
res = asn1_string_to_utf8(string_time) tz_str = timestring[14:]
before_tz = res[:14]
tz_str = res[14:]
d = datetime.datetime.strptime(before_tz, "%Y%m%d%H%M%S") d = datetime.datetime.strptime(before_tz, "%Y%m%d%H%M%S")
if tz_str == 'Z': if tz_str == 'Z':
# YYYYMMDDhhmmssZ # YYYYMMDDhhmmssZ
@ -126,25 +102,85 @@ def timestamp_to_asn1_time(t):
""" """
d = datetime.datetime.utcfromtimestamp(t) d = datetime.datetime.utcfromtimestamp(t)
# use the ASN1_GENERALIZEDTIME format asn1time = rfc2459.Time()
time_str = d.strftime("%Y%m%d%H%M%SZ").encode('ascii') if d.year <= 2049:
asn1_time = backend._lib.ASN1_STRING_type_new( time_str = d.strftime("%y%m%d%H%M%SZ").encode('ascii')
backend._lib.V_ASN1_GENERALIZEDTIME) asn1time['utcTime'] = time_str
backend._lib.ASN1_STRING_set(asn1_time, time_str, len(time_str)) else:
asn1_gentime = backend._ffi.cast("ASN1_GENERALIZEDTIME*", asn1_time) time_str = d.strftime("%Y%m%d%H%M%SZ").encode('ascii')
if backend._lib.ASN1_GENERALIZEDTIME_check(asn1_gentime) == 0: asn1time['generalTime'] = time_str
raise errors.ASN1TimeError("timestamp not accepted by ASN1 check") return asn1time
# ASN1_GENERALIZEDTIME is a form of ASN1_TIME, so a pointer cast is valid
return backend._ffi.cast("ASN1_TIME*", asn1_time)
def asn1_string_to_utf8(asn1_string): # functions needed for converting the pyasn1 signature fields
buf = backend._ffi.new("unsigned char **") def bin_to_bytes(bits):
res = backend._lib.ASN1_STRING_to_UTF8(buf, asn1_string) """Convert bit string to byte string."""
if res < 0 or buf[0] == backend._ffi.NULL: bits = ''.join(str(b) for b in bits)
raise errors.ASN1StringError("cannot convert asn1 to python string") bits = _pad_byte(bits)
buf = backend._ffi.gc( octets = [bits[8*i:8*(i+1)] for i in range(len(bits)/8)]
buf, lambda buffer: backend._lib.OPENSSL_free(buffer[0]) bytes = [chr(int(x, 2)) for x in octets]
) return "".join(bytes)
return backend._ffi.buffer(buf[0], res)[:].decode('utf8')
# ord good for py2 and py3
local_ord = ord if str is bytes else lambda x: x
def _pad_byte(bits):
"""Pad a string of bits with zeros to make its length a multiple of 8."""
r = len(bits) % 8
return ((8-r) % 8)*'0' + bits
def bytes_to_bin(bytes):
"""Convert byte string to bit string."""
return "".join([_pad_byte(_int_to_bin(local_ord(byte))) for byte in bytes])
def _int_to_bin(n):
if n == 0 or n == 1:
return str(n)
elif n % 2 == 0:
return _int_to_bin(n // 2) + "0"
else:
return _int_to_bin(n // 2) + "1"
def get_hash_class(md):
return getattr(hashes, md.upper(), None)
def get_private_key_from_bytes(data):
key = serialization.load_pem_private_key(
data, None, backend=backends.default_backend())
return key
def get_private_key_from_file(path):
with open(path, 'rb') as f:
return get_private_key_from_bytes(f.read())
def asn1_to_netaddr(octet_string):
"""Translate the ASN1 IP format to netaddr object."""
if not isinstance(octet_string, rfc2459.univ.OctetString):
raise TypeError("not an OctetString")
ip_bytes = octet_string.asOctets()
if len(ip_bytes) == 4:
ip_num = struct.unpack(">I", ip_bytes)[0]
return netaddr.IPAddress(ip_num, 4)
elif len(ip_bytes) == 16:
ip_num_front, ip_num_back = struct.unpack(">QQ", ip_bytes)
ip_num = ip_num_front << 64 | ip_num_back
return netaddr.IPAddress(ip_num, 6)
else:
raise TypeError("ip address is neither v4 nor v6")
def netaddr_to_asn1(ip):
"""Translate the netaddr object to ASN1 IP format."""
if not isinstance(ip, netaddr.IPAddress):
raise errors.X509Error("not a real ip address provided")
return bytes(ip.packed)

View File

@ -25,7 +25,7 @@ from anchor import jsonloader
from anchor import validators from anchor import validators
from anchor.X509 import certificate from anchor.X509 import certificate
from anchor.X509 import signing_request from anchor.X509 import signing_request
from anchor.X509 import utils as X509_utils from anchor.X509 import utils
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -54,8 +54,7 @@ def parse_csr(csr, encoding):
# load the CSR into the backend X509 library # load the CSR into the backend X509 library
try: try:
out_req = signing_request.X509Csr() out_req = signing_request.X509Csr.from_buffer(csr)
out_req.from_buffer(csr)
return out_req return out_req
except Exception as e: except Exception as e:
logger.exception("Exception while parsing the CSR: %s", e) logger.exception("Exception while parsing the CSR: %s", e)
@ -132,17 +131,14 @@ def sign(csr):
:param csr: X509 certificate signing request :param csr: X509 certificate signing request
""" """
try: try:
ca = certificate.X509Certificate() ca = certificate.X509Certificate.from_file(
ca.from_file(jsonloader.conf.ca["cert_path"]) jsonloader.conf.ca["cert_path"])
except Exception as e: except Exception as e:
logger.exception("Cannot load the signing CA: %s", e) logger.exception("Cannot load the signing CA: %s", e)
pecan.abort(500, "certificate signing error") pecan.abort(500, "certificate signing error")
try: try:
key_data = None key = utils.get_private_key_from_file(jsonloader.conf.ca['key_path'])
with open(jsonloader.conf.ca["key_path"]) as f:
key_data = f.read()
key = X509_utils.load_pem_private_key(key_data)
except Exception as e: except Exception as e:
logger.exception("Cannot load the signing CA key: %s", e) logger.exception("Cannot load the signing CA key: %s", e)
pecan.abort(500, "certificate signing error") pecan.abort(500, "certificate signing error")
@ -182,7 +178,7 @@ def sign(csr):
cert_pem = new_cert.as_pem() cert_pem = new_cert.as_pem()
with open(path, "wb") as f: with open(path, "w") as f:
f.write(cert_pem) f.write(cert_pem)
return cert_pem return cert_pem

View File

@ -17,6 +17,7 @@ import logging
import netaddr import netaddr
from anchor.X509 import extension
from anchor.X509 import name as x509_name from anchor.X509 import name as x509_name
@ -29,7 +30,7 @@ class ValidationError(Exception):
def csr_get_cn(csr): def csr_get_cn(csr):
name = csr.get_subject() name = csr.get_subject()
data = name.get_entries_by_nid(x509_name.NID_commonName) data = name.get_entries_by_oid(x509_name.OID_commonName)
if len(data) > 0: if len(data) > 0:
return data[0].get_value() return data[0].get_value()
else: else:
@ -51,19 +52,14 @@ def check_domains(domain, allowed_domains):
def iter_alternative_names(csr, types, fail_other_types=True): def iter_alternative_names(csr, types, fail_other_types=True):
for ext in csr.get_extensions(): for ext in csr.get_extensions():
if ext.get_name() == "subjectAltName": if isinstance(ext, extension.X509ExtensionSubjectAltName):
alternatives = [alt.strip() for alt in ext.get_value().split(',')] # TODO(stan): fail on other types
for alternative in alternatives: if 'DNS' in types:
parts = alternative.split(':', 1) for dns_id in ext.get_dns_ids():
if len(parts) != 2: yield ('DNS', dns_id)
# it has at least one part, so parts[0] is valid if 'IP Address' in types:
raise ValidationError("Alt name should have 2 parts, but " for ip in ext.get_ips():
"found: '%s'" % parts[0]) yield ('IP Address', ip)
if parts[0] in types:
yield parts
elif fail_other_types:
raise ValidationError("Alt name '%s' has unexpected type "
"'%s'" % (parts[1], parts[0]))
def check_networks(ip, allowed_networks): def check_networks(ip, allowed_networks):
@ -91,15 +87,15 @@ def common_name(csr, allowed_domains=[], allowed_networks=[], **kwargs):
alt_present = any(ext.get_name() == "subjectAltName" alt_present = any(ext.get_name() == "subjectAltName"
for ext in csr.get_extensions()) for ext in csr.get_extensions())
CNs = csr.get_subject().get_entries_by_nid(x509_name.NID_commonName) CNs = csr.get_subject().get_entries_by_oid(x509_name.OID_commonName)
if len(CNs) > 1: if len(CNs) > 1:
raise ValidationError("Too many CNs in the request") raise ValidationError("Too many CNs in the request")
if not alt_present:
# rfc5280#section-4.2.1.6 says so # rfc5280#section-4.2.1.6 says so
if len(CNs) == 0: if len(CNs) == 0 and not alt_present:
raise ValidationError("Alt subjects have to exist if the main" raise ValidationError("Alt subjects have to exist if the main"
" subject doesn't") " subject doesn't")
if len(CNs) > 0: if len(CNs) > 0:
cn = csr_get_cn(csr) cn = csr_get_cn(csr)
@ -122,7 +118,7 @@ def alternative_names(csr, allowed_domains=[], **kwargs):
the list of known suffixes, or network ranges. the list of known suffixes, or network ranges.
""" """
for name_type, name in iter_alternative_names(csr, ['DNS']): for _, name in iter_alternative_names(csr, ['DNS']):
if not check_domains(name, allowed_domains): if not check_domains(name, allowed_domains):
raise ValidationError("Domain '%s' not allowed (doesn't" raise ValidationError("Domain '%s' not allowed (doesn't"
" match known domains)" " match known domains)"
@ -142,9 +138,8 @@ def alternative_names_ip(csr, allowed_domains=[], allowed_networks=[],
raise ValidationError("Domain '%s' not allowed (doesn't" raise ValidationError("Domain '%s' not allowed (doesn't"
" match known domains)" % name) " match known domains)" % name)
if name_type == 'IP Address': if name_type == 'IP Address':
ip = netaddr.IPAddress(name) if not check_networks(name, allowed_networks):
if not check_networks(ip, allowed_networks): raise ValidationError("IP '%s' not allowed (doesn't"
raise ValidationError("Address '%s' not allowed (doesn't"
" match known networks)" % name) " match known networks)" % name)
@ -156,7 +151,7 @@ def blacklist_names(csr, domains=[], **kwargs):
"consider disabling the step or providing a list") "consider disabling the step or providing a list")
return return
CNs = csr.get_subject().get_entries_by_nid(x509_name.NID_commonName) CNs = csr.get_subject().get_entries_by_oid(x509_name.OID_commonName)
if len(CNs) > 0: if len(CNs) > 0:
cn = csr_get_cn(csr) cn = csr_get_cn(csr)
if check_domains(cn, domains): if check_domains(cn, domains):
@ -198,45 +193,41 @@ def extensions(csr=None, allowed_extensions=[], **kwargs):
def key_usage(csr=None, allowed_usage=None, **kwargs): def key_usage(csr=None, allowed_usage=None, **kwargs):
"""Ensure only accepted key usages are specified.""" """Ensure only accepted key usages are specified."""
allowed = set(allowed_usage) allowed = set(extension.LONG_KEY_USAGE_NAMES.get(x, x) for x in
allowed_usage)
denied = set()
for ext in (csr.get_extensions() or []): for ext in (csr.get_extensions() or []):
if ext.get_name() == 'keyUsage': if isinstance(ext, extension.X509ExtensionKeyUsage):
usages = set(usage.strip() for usage in ext.get_value().split(',')) usages = set(ext.get_all_usages())
if usages & allowed != usages: denied = denied | (usages - allowed)
raise ValidationError("Found some not allowed key usages: %s" if denied:
% ', '.join(usages - allowed)) raise ValidationError("Found some not allowed key usages: %s"
% ', '.join(denied))
def ca_status(csr=None, ca_requested=False, **kwargs): def ca_status(csr=None, ca_requested=False, **kwargs):
"""Ensure the request has/hasn't got the CA flag.""" """Ensure the request has/hasn't got the CA flag."""
request_ca_flags = False
for ext in (csr.get_extensions() or []): for ext in (csr.get_extensions() or []):
ext_name = ext.get_name() if isinstance(ext, extension.X509ExtensionBasicConstraints):
if ext_name == 'basicConstraints': if ext.get_ca():
options = [opt.strip() for opt in ext.get_value().split(",")] if not ca_requested:
for option in options: raise ValidationError(
parts = option.split(":") "CA status requested, but not allowed")
if len(parts) != 2: request_ca_flags = True
raise ValidationError("Invalid basic constraints flag") elif isinstance(ext, extension.X509ExtensionKeyUsage):
has_cert_sign = ext.get_usage('keyCertSign')
if parts[0] == 'CA': has_crl_sign = ext.get_usage('cRLSign')
if parts[1] != str(ca_requested).upper(): if has_crl_sign or has_cert_sign:
raise ValidationError("Invalid CA status, 'CA:%s'" if not ca_requested:
" requested" % parts[1]) raise ValidationError(
elif parts[0] == 'pathlen': "Key usage doesn't match requested CA status "
# errr.. it's ok, I guess "(keyCertSign/cRLSign: %s/%s)"
pass % (has_cert_sign, has_crl_sign))
else: request_ca_flags = True
raise ValidationError("Invalid basic constraints option") if ca_requested and not request_ca_flags:
elif ext_name == 'keyUsage': raise ValidationError("CA flags required")
usages = set(usage.strip() for usage in ext.get_value().split(','))
has_cert_sign = ('Certificate Sign' in usages)
has_crl_sign = ('CRL Sign' in usages)
if ca_requested != has_cert_sign or ca_requested != has_crl_sign:
raise ValidationError("Key usage doesn't match requested CA"
" status (keyCertSign/cRLSign: %s/%s)"
% (has_cert_sign, has_crl_sign))
def source_cidrs(request=None, cidrs=None, **kwargs): def source_cidrs(request=None, cidrs=None, **kwargs):

View File

@ -2,6 +2,8 @@
# of appearance. Changing the order has an impact on the overall integration # of appearance. Changing the order has an impact on the overall integration
# process, which may cause wedges in the gate later. # process, which may cause wedges in the gate later.
cryptography>=0.9.1 # Apache-2.0 cryptography>=0.9.1 # Apache-2.0
pyasn1
pyasn1_modules
pecan>=0.8.0 pecan>=0.8.0
Paste Paste
netaddr>=0.7.12 netaddr>=0.7.12

View File

@ -0,0 +1,144 @@
# -*- coding:utf-8 -*-
#
# Copyright 2014 Hewlett-Packard Development Company, L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import unittest
import netaddr
from pyasn1_modules import rfc2459 # X509v3
from anchor.X509 import errors
from anchor.X509 import extension
class TestExtensionBase(unittest.TestCase):
def test_no_spec(self):
with self.assertRaises(errors.X509Error):
extension.X509Extension()
def test_invalid_asn(self):
with self.assertRaises(errors.X509Error):
extension.X509Extension("foobar")
def test_unknown_extension_str(self):
asn1 = rfc2459.Extension()
asn1['extnID'] = rfc2459.univ.ObjectIdentifier('1.2.3.4')
asn1['critical'] = False
asn1['extnValue'] = "foobar"
ext = extension.X509Extension(asn1)
self.assertEqual("1.2.3.4: <unknown>", str(ext))
def test_construct(self):
asn1 = rfc2459.Extension()
asn1['extnID'] = rfc2459.univ.ObjectIdentifier('1.2.3.4')
asn1['critical'] = False
asn1['extnValue'] = "foobar"
ext = extension.construct_extension(asn1)
self.assertIsInstance(ext, extension.X509Extension)
def test_construct_invalid_type(self):
with self.assertRaises(errors.X509Error):
extension.construct_extension("foobar")
def test_critical(self):
asn1 = rfc2459.Extension()
asn1['extnID'] = rfc2459.univ.ObjectIdentifier('1.2.3.4')
asn1['critical'] = False
asn1['extnValue'] = "foobar"
ext = extension.construct_extension(asn1)
self.assertFalse(ext.get_critical())
ext.set_critical(True)
self.assertTrue(ext.get_critical())
class TestBasicConstraints(unittest.TestCase):
def setUp(self):
self.ext = extension.X509ExtensionBasicConstraints()
def test_str(self):
self.assertEqual(str(self.ext),
"basicConstraints: CA: FALSE, pathLen: None")
def test_ca(self):
self.ext.set_ca(True)
self.assertTrue(self.ext.get_ca())
self.ext.set_ca(False)
self.assertFalse(self.ext.get_ca())
def test_pathlen(self):
self.ext.set_path_len_constraint(1)
self.assertEqual(1, self.ext.get_path_len_constraint())
class TestKeyUsage(unittest.TestCase):
def setUp(self):
self.ext = extension.X509ExtensionKeyUsage()
def test_usage_set(self):
self.ext.set_usage('digitalSignature', True)
self.ext.set_usage('keyAgreement', False)
self.assertTrue(self.ext.get_usage('digitalSignature'))
self.assertFalse(self.ext.get_usage('keyAgreement'))
def test_usage_reset(self):
self.ext.set_usage('digitalSignature', True)
self.ext.set_usage('digitalSignature', False)
self.assertFalse(self.ext.get_usage('digitalSignature'))
def test_usage_unset(self):
self.assertFalse(self.ext.get_usage('keyAgreement'))
def test_get_all_usage(self):
self.ext.set_usage('digitalSignature', True)
self.ext.set_usage('keyAgreement', False)
self.ext.set_usage('keyEncipherment', True)
self.assertEqual(set(['digitalSignature', 'keyEncipherment']),
set(self.ext.get_all_usages()))
def test_str(self):
self.ext.set_usage('digitalSignature', True)
self.assertEqual("keyUsage: digitalSignature", str(self.ext))
class TestSubjectAltName(unittest.TestCase):
def setUp(self):
self.ext = extension.X509ExtensionSubjectAltName()
self.domain = 'some.domain'
self.ip = netaddr.IPAddress('1.2.3.4')
self.ip6 = netaddr.IPAddress('::1')
def test_dns_ids(self):
self.ext.add_dns_id(self.domain)
self.ext.add_ip(self.ip)
self.assertEqual([self.domain], self.ext.get_dns_ids())
def test_ips(self):
self.ext.add_dns_id(self.domain)
self.ext.add_ip(self.ip)
self.assertEqual([self.ip], self.ext.get_ips())
def test_ipv6(self):
self.ext.add_ip(self.ip6)
self.assertEqual([self.ip6], self.ext.get_ips())
def test_add_ip_invalid(self):
with self.assertRaises(errors.X509Error):
self.ext.add_ip("abcdef")
def test_str(self):
self.ext.add_dns_id(self.domain)
self.ext.add_ip(self.ip)
self.assertEqual("subjectAltName: DNS:some.domain, IP:1.2.3.4",
str(self.ext))

View File

@ -1,92 +0,0 @@
# -*- coding:utf-8 -*-
#
# Copyright 2014 Hewlett-Packard Development Company, L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import unittest
from anchor.X509 import message_digest
class TestMessageDigest(unittest.TestCase):
data = b"this is test data to test with"
def setUp(self):
super(TestMessageDigest, self).setUp()
def tearDown(self):
super(TestMessageDigest, self).tearDown()
def test_bad_algo(self):
self.assertRaises(message_digest.MessageDigestError,
message_digest.MessageDigest,
'BAD')
def test_md5(self):
v = "B2F81E9F287884AF6A8B3E8EFB96C711"
md = message_digest.MessageDigest("md5")
md.update(TestMessageDigest.data)
ret = md.final()
self.assertEqual(ret, v)
def test_ripmed160(self):
v = "BA5CCC4574D676266D821269CA77BFFD7FD9FCB0"
md = message_digest.MessageDigest("ripemd160")
md.update(TestMessageDigest.data)
ret = md.final()
self.assertEqual(ret, v)
def test_sha224(self):
v = "675170C12E88D549DB0F608AD6857103D7B792F29FACFCC53173F178"
md = message_digest.MessageDigest("sha224")
md.update(TestMessageDigest.data)
ret = md.final()
self.assertEqual(ret, v)
def test_sha256(self):
v = "91F672E796E84BECC6F051A47D7392BD789AEA7D55090588F212CF041C862678"
md = message_digest.MessageDigest("sha256")
md.update(TestMessageDigest.data)
ret = md.final()
self.assertEqual(ret, v)
def test_sha384(self):
v = ("9667AF42DF2E6B81EE679757BB207A3F9BB7CED49CF838FF3ED8237C9B15291B"
"15")
md = message_digest.MessageDigest("sha384")
md.update(TestMessageDigest.data)
ret = md.final()
self.assertEqual(ret, v)
def test_sha512(self):
v = ("283B3ECD8AE687226C3EA46B59F65E5CA50A11735C9C14BED11F0CCB515707B5"
"1031145ED8AE4B35B24B91F26E70AC0ACAC37B5BEE933B28834FE6447D1298CB"
)
md = message_digest.MessageDigest("sha512")
md.update(TestMessageDigest.data)
ret = md.final()
self.assertEqual(ret, v)
def test_algorithms(self):
algs = [
"md5",
"ripemd160",
"sha224",
"sha256",
"sha384",
"sha512"
]
valid = message_digest.MessageDigest.getValidAlgorithms()
for alg in algs:
self.assertTrue(alg in valid)

View File

@ -14,70 +14,21 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
import datetime
import unittest import unittest
import mock
from anchor.X509 import errors
from anchor.X509 import utils from anchor.X509 import utils
from cryptography.hazmat.backends.openssl import backend
class TestASN1String(unittest.TestCase):
# missing in cryptography.io
V_ASN1_UTF8STRING = 12
def test_utf8_string(self):
orig = u"test \u2603 snowman"
encoded = orig.encode('utf-8')
asn1string = backend._lib.ASN1_STRING_type_new(self.V_ASN1_UTF8STRING)
backend._lib.ASN1_STRING_set(asn1string, encoded, len(encoded))
res = utils.asn1_string_to_utf8(asn1string)
self.assertEqual(res, orig)
def test_invalid_string(self):
encoded = b"\xff"
asn1string = backend._lib.ASN1_STRING_type_new(self.V_ASN1_UTF8STRING)
backend._lib.ASN1_STRING_set(asn1string, encoded, len(encoded))
self.assertRaises(errors.ASN1StringError, utils.asn1_string_to_utf8,
asn1string)
class TestASN1Time(unittest.TestCase): class TestASN1Time(unittest.TestCase):
def test_conversion_failure(self): def test_round_check(self):
with mock.patch.object(backend._lib, "ASN1_TIME_to_generalizedtime", t = 0
return_value=backend._ffi.NULL): asn1_time = utils.timestamp_to_asn1_time(t)
t = utils.timestamp_to_asn1_time(0) res = utils.asn1_time_to_timestamp(asn1_time)
self.assertRaises(errors.ASN1TimeError, self.assertEqual(t, res)
utils.asn1_time_to_timestamp, t)
def test_generalizedtime_check_failure(self): def test_post_2050(self):
with mock.patch.object(backend._lib, "ASN1_GENERALIZEDTIME_check", """Test date post 2050, which causes different encoding."""
return_value=0): t = 2600000000
self.assertRaises(errors.ASN1TimeError, asn1_time = utils.timestamp_to_asn1_time(t)
utils.timestamp_to_asn1_time, 0) res = utils.asn1_time_to_timestamp(asn1_time)
self.assertEqual(t, res)
class TestTimezone(unittest.TestCase):
def test_utcoffset(self):
tz = utils.create_timezone(1234)
offset = tz.utcoffset(datetime.datetime.now())
self.assertEqual(datetime.timedelta(minutes=1234), offset)
def test_dst(self):
tz = utils.create_timezone(1234)
offset = tz.dst(datetime.datetime.now())
self.assertEqual(datetime.timedelta(0), offset)
def test_name(self):
tz = utils.create_timezone(1234)
name = tz.tzname(datetime.datetime.now())
self.assertIsNone(name)
def test_repr(self):
tz = utils.create_timezone(1234)
self.assertEqual("Timezone +2034", repr(tz))

View File

@ -18,25 +18,18 @@ import unittest
import mock import mock
import sys import io
import textwrap import textwrap
from anchor.X509 import certificate from anchor.X509 import certificate
from anchor.X509 import errors as x509_errors from anchor.X509 import errors as x509_errors
from anchor.X509 import extension
from anchor.X509 import name as x509_name from anchor.X509 import name as x509_name
from anchor.X509 import utils
# find the class representing an open file; it depends on the python version
# it's used later for mocking
if sys.version_info[0] < 3:
file_class = file # noqa
else:
import _io
file_class = _io.TextIOWrapper
class TestX509Cert(unittest.TestCase): class TestX509Cert(unittest.TestCase):
cert_data = textwrap.dedent(""" cert_data = textwrap.dedent(u"""
-----BEGIN CERTIFICATE----- -----BEGIN CERTIFICATE-----
MIICKjCCAZOgAwIBAgIIfeW6dwGe6wMwDQYJKoZIhvcNAQEFBQAwUjELMAkGA1UE MIICKjCCAZOgAwIBAgIIfeW6dwGe6wMwDQYJKoZIhvcNAQEFBQAwUjELMAkGA1UE
BhMCQVUxEzARBgNVBAgTClNvbWUtU3RhdGUxFjAUBgNVBAoTDUhlcnAgRGVycCBw BhMCQVUxEzARBgNVBAgTClNvbWUtU3RhdGUxFjAUBgNVBAoTDUhlcnAgRGVycCBw
@ -52,17 +45,70 @@ class TestX509Cert(unittest.TestCase):
gTLni27WuVJFVBNoTU1JfoxBSm/RBLdTj92g9N5g gTLni27WuVJFVBNoTU1JfoxBSm/RBLdTj92g9N5g
-----END CERTIFICATE-----""") -----END CERTIFICATE-----""")
key_dsa_data = textwrap.dedent("""
-----BEGIN DSA PARAMETERS-----
MIICLAKCAQEA59W1OsK9Tv7DRbxzibGVpBAL2Oz8JhbV3ii7WAat+UfTBLAnfdva
7UE8odu1l8p41N/8H/tDWgPh6tOgdX0YT9HDsILymQxzUEscliFZKmYg7YdSH3Zd
6DglOT7CqYxX0r9gK/BOh8ESe3gqKncnThHnO8Eu9wP8HNcrN00EOqP+fJpbS0lu
iifD9JdFY5YpCsLDIvpPbM0NCDuANPo10N3qqC8BuNiu0VfZpRSBcqzU1kwABT5n
y7+8RMh5Xaa7xnhGctJ9s9n+QfWcF/vbgiDOBttb3d8r8Pqvoou8v7Q38Q6zILhf
hajevqjGqZwodbvbHGfFbWapgBjpBIr4zwIhAOq6uryEHQglirWCGFJLQlkzxghy
ctHBRXGuKYb+ltRTAoIBAHRUFxzd1vhjKQ5atIdG0AiXUNm7/uboe21EJDLf4lkE
7UHDZfwsHXxQHfozzIsp7gHcw7F6AVCgiNRi9vBYOemPswevoWiVKqLTVt1wMogD
EJI6VAQEbBmSrtvyuClCkEAlIY6daX9EV9KqbnetS4/xv4WFQ9FPE47VyQ50vvxK
JSyNZnJ1lN6FUD9R5YYfwERgND8EYJBD10UBKIvtORICTJUfaDAweTWhaVcXUID7
VGNGPauOdVQzWsWTrQn/f/hbXCB/KXgv1l92D6rEoT2j2YrqIv/qD/ZxPwhBfLdr
W241Cb+LT05LVCokRbWUdjfuO8SdSBAIvT9P6umG/uQ=
-----END DSA PARAMETERS-----
-----BEGIN DSA PRIVATE KEY-----
MIIDVwIBAAKCAQEA59W1OsK9Tv7DRbxzibGVpBAL2Oz8JhbV3ii7WAat+UfTBLAn
fdva7UE8odu1l8p41N/8H/tDWgPh6tOgdX0YT9HDsILymQxzUEscliFZKmYg7YdS
H3Zd6DglOT7CqYxX0r9gK/BOh8ESe3gqKncnThHnO8Eu9wP8HNcrN00EOqP+fJpb
S0luiifD9JdFY5YpCsLDIvpPbM0NCDuANPo10N3qqC8BuNiu0VfZpRSBcqzU1kwA
BT5ny7+8RMh5Xaa7xnhGctJ9s9n+QfWcF/vbgiDOBttb3d8r8Pqvoou8v7Q38Q6z
ILhfhajevqjGqZwodbvbHGfFbWapgBjpBIr4zwIhAOq6uryEHQglirWCGFJLQlkz
xghyctHBRXGuKYb+ltRTAoIBAHRUFxzd1vhjKQ5atIdG0AiXUNm7/uboe21EJDLf
4lkE7UHDZfwsHXxQHfozzIsp7gHcw7F6AVCgiNRi9vBYOemPswevoWiVKqLTVt1w
MogDEJI6VAQEbBmSrtvyuClCkEAlIY6daX9EV9KqbnetS4/xv4WFQ9FPE47VyQ50
vvxKJSyNZnJ1lN6FUD9R5YYfwERgND8EYJBD10UBKIvtORICTJUfaDAweTWhaVcX
UID7VGNGPauOdVQzWsWTrQn/f/hbXCB/KXgv1l92D6rEoT2j2YrqIv/qD/ZxPwhB
fLdrW241Cb+LT05LVCokRbWUdjfuO8SdSBAIvT9P6umG/uQCggEBAKrZAppbnKf1
pzSvE3gTaloitAJG+79BML5h1n67EWuv0i+Fq4eUAVJ23R8GR1HrYw6utZoYbu8u
k8eHrArMfTfbFaLwK/Nv33Hfm3aTTXnY6auLNkpbiZXuCQjWBFhb6F+B42V9/JJ8
RJ1UV6Y2ajjjMvpeh0cPlARw5UpKBgQ933DhefCWyFBPsPToFvd3uPO+GUN6VpNY
iR7G0AH3/LSVJRuz5/QCp86uLIoU3fBEf1KGYJrkVKlc9DtcNmDXgpP0d3fK+4Jw
bGvi5AD1sQOWryNujyS/d2K/PAagsD0M6XJFgkEV592OSlygbYtuo3t4AtAy8F0f
VHNXq2l01FMCIQCrkk1749eQg4W6j7HfLFvjbDcuIFTw98IKyEZuZ93cdA==
-----END DSA PRIVATE KEY-----""").encode('ascii')
key_rsa_data = textwrap.dedent("""
-----BEGIN RSA PRIVATE KEY-----
MIICXAIBAAKBgQCeeqg1Qeccv8hqj1BP9KEJX5QsFCxR62M8plPb5t4sLo8UYfZd
6kFLcOP8xzwwvx/eFY6Sux52enQ197o8aMwyP77hMhZqtd8NCgLJMVlUbRhwLti0
SkHFPic0wAg+esfXa6yhd5TxC+bti7MgV/ljA80XQxHH8xOjdOoGN0DHfQIDAQAB
AoGBAJ2ozJpe+7qgGJPaCz3f0izvBwtq7kR49fqqRZbo8HHnx7OxWVVI7LhOkKEy
2/Bq0xsvOu1CdiXL4LynvIDIiQqLaeINzG48Rbk+0HadbXblt3nDkIWdYII6zHKI
W9ewX4KpHEPbrlEO9BjAlAcYsDIvFIMYpQhtQ+0R/gmZ99WJAkEAz5C2a6FIcMbE
o3aTc9ECq99zY7lxh+6aLpUdIeeHyb/QzfGDBdlbpBAkA6EcxSqp0aqH4xIQnYHa
3P5ZCShqSwJBAMN1sb76xq94xkg2cxShPFPAE6xKRFyKqLgsBYVtulOdfOtOnjh9
1SK2XQQfBRIRdG4Q/gDoCP8XQHpJcWMk+FcCQDnuJqulaOVo5GrG5mJ1nCxCAh98
G06X7lo/7dCPoRtSuMExvaK9RlFk29hTeAcjYCAPWzupyA9dtarmJg1jRT8CQCKf
gYnb8D/6+9yk0IPR/9ayCooVacCeyz48hgnZowzWs98WwQ4utAd/GED3obVOpDov
Bl9wus889i3zPoOac+cCQCZHredQcJGd4dlthbVtP2NhuPXz33JuETGR9pXtsDUZ
uX/nSq1oo9kUh/dPOz6aP5Ues1YVe3LExmExPBQfwIE=
-----END RSA PRIVATE KEY-----""").encode('ascii')
def setUp(self): def setUp(self):
super(TestX509Cert, self).setUp() super(TestX509Cert, self).setUp()
self.cert = certificate.X509Certificate() self.cert = certificate.X509Certificate.from_buffer(
self.cert.from_buffer(TestX509Cert.cert_data) TestX509Cert.cert_data)
def tearDown(self): def tearDown(self):
pass pass
def test_bad_data_throws(self): def test_bad_data_throws(self):
bad_data = ( bad_data = (
"some bad data is " u"some bad data is "
"EHRlc3RAYW5jaG9yLnRlc3QwTDANBgkqhkiG9w0BAQEFAAM7ADA4AjEA6m") "EHRlc3RAYW5jaG9yLnRlc3QwTDANBgkqhkiG9w0BAQEFAAM7ADA4AjEA6m")
cert = certificate.X509Certificate() cert = certificate.X509Certificate()
@ -72,119 +118,121 @@ class TestX509Cert(unittest.TestCase):
def test_get_subject_countryName(self): def test_get_subject_countryName(self):
name = self.cert.get_subject() name = self.cert.get_subject()
entries = name.get_entries_by_nid(x509_name.NID_countryName) entries = name.get_entries_by_oid(x509_name.OID_countryName)
self.assertEqual(len(entries), 1) self.assertEqual(len(entries), 1)
self.assertEqual(entries[0].get_name(), "countryName") self.assertEqual(entries[0].get_name(), "countryName")
self.assertEqual(entries[0].get_value(), "UK") self.assertEqual(entries[0].get_value(), "UK")
def test_get_subject_stateOrProvinceName(self): def test_get_subject_stateOrProvinceName(self):
name = self.cert.get_subject() name = self.cert.get_subject()
entries = name.get_entries_by_nid(x509_name.NID_stateOrProvinceName) entries = name.get_entries_by_oid(x509_name.OID_stateOrProvinceName)
self.assertEqual(len(entries), 1) self.assertEqual(len(entries), 1)
self.assertEqual(entries[0].get_name(), "stateOrProvinceName") self.assertEqual(entries[0].get_name(), "stateOrProvinceName")
self.assertEqual(entries[0].get_value(), "Narnia") self.assertEqual(entries[0].get_value(), "Narnia")
def test_get_subject_localityName(self): def test_get_subject_localityName(self):
name = self.cert.get_subject() name = self.cert.get_subject()
entries = name.get_entries_by_nid(x509_name.NID_localityName) entries = name.get_entries_by_oid(x509_name.OID_localityName)
self.assertEqual(len(entries), 1) self.assertEqual(len(entries), 1)
self.assertEqual(entries[0].get_name(), "localityName") self.assertEqual(entries[0].get_name(), "localityName")
self.assertEqual(entries[0].get_value(), "Funkytown") self.assertEqual(entries[0].get_value(), "Funkytown")
def test_get_subject_organizationName(self): def test_get_subject_organizationName(self):
name = self.cert.get_subject() name = self.cert.get_subject()
entries = name.get_entries_by_nid(x509_name.NID_organizationName) entries = name.get_entries_by_oid(x509_name.OID_organizationName)
self.assertEqual(len(entries), 1) self.assertEqual(len(entries), 1)
self.assertEqual(entries[0].get_name(), "organizationName") self.assertEqual(entries[0].get_name(), "organizationName")
self.assertEqual(entries[0].get_value(), "Anchor Testing") self.assertEqual(entries[0].get_value(), "Anchor Testing")
def test_get_subject_organizationUnitName(self): def test_get_subject_organizationUnitName(self):
name = self.cert.get_subject() name = self.cert.get_subject()
entries = name.get_entries_by_nid(x509_name.NID_organizationalUnitName) entries = name.get_entries_by_oid(x509_name.OID_organizationalUnitName)
self.assertEqual(len(entries), 1) self.assertEqual(len(entries), 1)
self.assertEqual(entries[0].get_name(), "organizationalUnitName") self.assertEqual(entries[0].get_name(), "organizationalUnitName")
self.assertEqual(entries[0].get_value(), "testing") self.assertEqual(entries[0].get_value(), "testing")
def test_get_subject_commonName(self): def test_get_subject_commonName(self):
name = self.cert.get_subject() name = self.cert.get_subject()
entries = name.get_entries_by_nid(x509_name.NID_commonName) entries = name.get_entries_by_oid(x509_name.OID_commonName)
self.assertEqual(len(entries), 1) self.assertEqual(len(entries), 1)
self.assertEqual(entries[0].get_name(), "commonName") self.assertEqual(entries[0].get_name(), "commonName")
self.assertEqual(entries[0].get_value(), "anchor.test") self.assertEqual(entries[0].get_value(), "anchor.test")
def test_get_subject_emailAddress(self): def test_get_subject_emailAddress(self):
name = self.cert.get_subject() name = self.cert.get_subject()
entries = name.get_entries_by_nid(x509_name.NID_pkcs9_emailAddress) entries = name.get_entries_by_oid(x509_name.OID_pkcs9_emailAddress)
self.assertEqual(len(entries), 1) self.assertEqual(len(entries), 1)
self.assertEqual(entries[0].get_name(), "emailAddress") self.assertEqual(entries[0].get_name(), "emailAddress")
self.assertEqual(entries[0].get_value(), "test@anchor.test") self.assertEqual(entries[0].get_value(), "test@anchor.test")
def test_get_issuer_countryName(self): def test_get_issuer_countryName(self):
name = self.cert.get_issuer() name = self.cert.get_issuer()
entries = name.get_entries_by_nid(x509_name.NID_countryName) entries = name.get_entries_by_oid(x509_name.OID_countryName)
self.assertEqual(len(entries), 1) self.assertEqual(len(entries), 1)
self.assertEqual(entries[0].get_name(), "countryName") self.assertEqual(entries[0].get_name(), "countryName")
self.assertEqual(entries[0].get_value(), "AU") self.assertEqual(entries[0].get_value(), "AU")
def test_get_issuer_stateOrProvinceName(self): def test_get_issuer_stateOrProvinceName(self):
name = self.cert.get_issuer() name = self.cert.get_issuer()
entries = name.get_entries_by_nid(x509_name.NID_stateOrProvinceName) entries = name.get_entries_by_oid(x509_name.OID_stateOrProvinceName)
self.assertEqual(len(entries), 1) self.assertEqual(len(entries), 1)
self.assertEqual(entries[0].get_name(), "stateOrProvinceName") self.assertEqual(entries[0].get_name(), "stateOrProvinceName")
self.assertEqual(entries[0].get_value(), "Some-State") self.assertEqual(entries[0].get_value(), "Some-State")
def test_get_issuer_organizationName(self): def test_get_issuer_organizationName(self):
name = self.cert.get_issuer() name = self.cert.get_issuer()
entries = name.get_entries_by_nid(x509_name.NID_organizationName) entries = name.get_entries_by_oid(x509_name.OID_organizationName)
self.assertEqual(len(entries), 1) self.assertEqual(len(entries), 1)
self.assertEqual(entries[0].get_name(), "organizationName") self.assertEqual(entries[0].get_name(), "organizationName")
self.assertEqual(entries[0].get_value(), "Herp Derp plc") self.assertEqual(entries[0].get_value(), "Herp Derp plc")
def test_get_issuer_commonName(self): def test_get_issuer_commonName(self):
name = self.cert.get_issuer() name = self.cert.get_issuer()
entries = name.get_entries_by_nid(x509_name.NID_commonName) entries = name.get_entries_by_oid(x509_name.OID_commonName)
self.assertEqual(len(entries), 1) self.assertEqual(len(entries), 1)
self.assertEqual(entries[0].get_name(), "commonName") self.assertEqual(entries[0].get_name(), "commonName")
self.assertEqual(entries[0].get_value(), "herp.derp.plc") self.assertEqual(entries[0].get_value(), "herp.derp.plc")
def test_set_subject(self): def test_set_subject(self):
name = x509_name.X509Name() name = x509_name.X509Name()
name.add_name_entry(x509_name.NID_countryName, 'UK') name.add_name_entry(x509_name.OID_countryName, 'UK')
self.cert.set_subject(name) self.cert.set_subject(name)
name = self.cert.get_subject() name = self.cert.get_subject()
entries = name.get_entries_by_nid(x509_name.NID_countryName) entries = name.get_entries_by_oid(x509_name.OID_countryName)
self.assertEqual(len(entries), 1) self.assertEqual(len(entries), 1)
self.assertEqual(entries[0].get_name(), "countryName") self.assertEqual(entries[0].get_name(), "countryName")
self.assertEqual(entries[0].get_value(), "UK") self.assertEqual(entries[0].get_value(), "UK")
def test_set_issuer(self): def test_set_issuer(self):
name = x509_name.X509Name() name = x509_name.X509Name()
name.add_name_entry(x509_name.NID_countryName, 'UK') name.add_name_entry(x509_name.OID_countryName, 'UK')
self.cert.set_issuer(name) self.cert.set_issuer(name)
name = self.cert.get_issuer() name = self.cert.get_issuer()
entries = name.get_entries_by_nid(x509_name.NID_countryName) entries = name.get_entries_by_oid(x509_name.OID_countryName)
self.assertEqual(len(entries), 1) self.assertEqual(len(entries), 1)
self.assertEqual(entries[0].get_name(), "countryName") self.assertEqual(entries[0].get_name(), "countryName")
self.assertEqual(entries[0].get_value(), "UK") self.assertEqual(entries[0].get_value(), "UK")
def test_read_from_file(self): def test_read_from_file(self):
open_name = 'anchor.X509.certificate.open' open_name = 'anchor.X509.certificate.open'
f = io.StringIO(TestX509Cert.cert_data)
with mock.patch(open_name, create=True) as mock_open: with mock.patch(open_name, create=True) as mock_open:
mock_open.return_value = mock.MagicMock(spec=file_class) mock_open.return_value = f
m_file = mock_open.return_value.__enter__.return_value
m_file.read.return_value = TestX509Cert.cert_data
cert = certificate.X509Certificate() cert = certificate.X509Certificate.from_file("some_path")
cert.from_file("some_path")
name = cert.get_subject() name = cert.get_subject()
entries = name.get_entries_by_nid(x509_name.NID_countryName) entries = name.get_entries_by_oid(x509_name.OID_countryName)
self.assertEqual(entries[0].get_value(), "UK") self.assertEqual(entries[0].get_value(), "UK")
def test_get_fingerprint(self): def test_get_fingerprint(self):
fp = self.cert.get_fingerprint() fp = self.cert.get_fingerprint()
self.assertEqual(fp, "56D61AC583BDDD4B44EEB479EF6C998F") self.assertEqual(fp, "634A8CD10C81F1CD7A7E140921B4D9CA")
def test_get_fingerprint_invalid_hash(self):
with self.assertRaises(x509_errors.X509Error):
self.cert.get_fingerprint('no_such_hash')
def test_sign_bad_md(self): def test_sign_bad_md(self):
self.assertRaises(x509_errors.X509Error, self.assertRaises(x509_errors.X509Error,
@ -194,7 +242,7 @@ class TestX509Cert(unittest.TestCase):
def test_sign_bad_key(self): def test_sign_bad_key(self):
self.assertRaises(x509_errors.X509Error, self.assertRaises(x509_errors.X509Error,
self.cert.sign, self.cert.sign,
self.cert._ffi.NULL) None)
def test_get_version(self): def test_get_version(self):
v = self.cert.get_version() v = self.cert.get_version()
@ -222,3 +270,40 @@ class TestX509Cert(unittest.TestCase):
self.cert.set_not_after(0) # seconds since epoch self.cert.set_not_after(0) # seconds since epoch
val = self.cert.get_not_after() val = self.cert.get_not_after()
self.assertEqual(0, val) self.assertEqual(0, val)
def test_get_extensions(self):
exts = self.cert.get_extensions()
self.assertEqual(2, len(exts))
def test_add_extensions(self):
bc = extension.X509ExtensionBasicConstraints()
self.cert.add_extension(bc, 2)
exts = self.cert.get_extensions()
self.assertEqual(3, len(exts))
def test_add_extensions_invalid(self):
with self.assertRaises(x509_errors.X509Error):
self.cert.add_extension("abcdef", 2)
def test_sign_rsa_sha1(self):
key = utils.get_private_key_from_bytes(self.key_rsa_data)
self.cert.sign(key, 'sha1')
self.assertEqual(self.cert.get_fingerprint(),
"BA1B5C97D68EAE738FD10657E6F0B143")
def test_sign_dsa_sha1(self):
key = utils.get_private_key_from_bytes(self.key_dsa_data)
self.cert.sign(key, 'sha1')
# TODO(stan): add verification; DSA signatures are not
# deterministic which means right now we can only make sure it
# doesn't raise exceptions
def test_sign_unknown_key(self):
key = object()
with self.assertRaises(x509_errors.X509Error):
self.cert.sign(key, 'sha1')
def test_sign_unknown_hash(self):
key = utils.get_private_key_from_bytes(self.key_rsa_data)
with self.assertRaises(x509_errors.X509Error):
self.cert.sign(key, 'no_such_hash')

View File

@ -14,29 +14,21 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
import sys import io
import textwrap import textwrap
import unittest import unittest
from cryptography.hazmat.backends.openssl import backend
import mock import mock
from pyasn1_modules import rfc2459
from anchor.X509 import errors as x509_errors from anchor.X509 import errors as x509_errors
from anchor.X509 import extension
from anchor.X509 import name as x509_name from anchor.X509 import name as x509_name
from anchor.X509 import signing_request from anchor.X509 import signing_request
# find the class representing an open file; it depends on the python version
# it's used later for mocking
if sys.version_info[0] < 3:
file_class = file # noqa
else:
import _io
file_class = _io.TextIOWrapper
class TestX509Csr(unittest.TestCase): class TestX509Csr(unittest.TestCase):
csr_data = textwrap.dedent(""" csr_data = textwrap.dedent(u"""
-----BEGIN CERTIFICATE REQUEST----- -----BEGIN CERTIFICATE REQUEST-----
MIIBWTCCARMCAQAwgZQxCzAJBgNVBAYTAlVLMQ8wDQYDVQQIEwZOYXJuaWExEjAQ MIIBWTCCARMCAQAwgZQxCzAJBgNVBAYTAlVLMQ8wDQYDVQQIEwZOYXJuaWExEjAQ
BgNVBAcTCUZ1bmt5dG93bjEXMBUGA1UEChMOQW5jaG9yIFRlc3RpbmcxEDAOBgNV BgNVBAcTCUZ1bmt5dG93bjEXMBUGA1UEChMOQW5jaG9yIFRlc3RpbmcxEDAOBgNV
@ -50,41 +42,53 @@ class TestX509Csr(unittest.TestCase):
def setUp(self): def setUp(self):
super(TestX509Csr, self).setUp() super(TestX509Csr, self).setUp()
self.csr = signing_request.X509Csr() self.csr = signing_request.X509Csr.from_buffer(TestX509Csr.csr_data)
self.csr.from_buffer(TestX509Csr.csr_data)
def tearDown(self): def tearDown(self):
pass pass
def test_get_pubkey_bits(self): def test_get_pubkey(self):
# some OpenSSL gumph to test a reasonable attribute of the pubkey
pubkey = self.csr.get_pubkey() pubkey = self.csr.get_pubkey()
size = backend._lib.EVP_PKEY_bits(pubkey) self.assertEqual(pubkey['algorithm']['algorithm'],
self.assertEqual(size, 384) rfc2459.rsaEncryption)
def test_get_extensions(self): def test_get_extensions(self):
exts = self.csr.get_extensions() exts = self.csr.get_extensions()
self.assertEqual(len(exts), 2) self.assertEqual(len(exts), 2)
self.assertEqual(str(exts[0]), "basicConstraints CA:FALSE") self.assertFalse(exts[0].get_ca())
self.assertEqual(str(exts[1]), ("keyUsage Digital Signature, Non " self.assertIsNone(exts[0].get_path_len_constraint())
"Repudiation, Key Encipherment")) self.assertTrue(exts[1].get_usage('digitalSignature'))
self.assertTrue(exts[1].get_usage('nonRepudiation'))
self.assertTrue(exts[1].get_usage('keyEncipherment'))
self.assertFalse(exts[1].get_usage('cRLSign'))
def test_add_extension(self):
csr = signing_request.X509Csr()
bc = extension.X509ExtensionBasicConstraints()
csr.add_extension(bc)
self.assertEqual(1, len(csr.get_extensions()))
csr.add_extension(bc)
self.assertEqual(2, len(csr.get_extensions()))
def test_add_extension_invalid_type(self):
csr = signing_request.X509Csr()
with self.assertRaises(x509_errors.X509Error):
csr.add_extension(1234)
def test_read_from_file(self): def test_read_from_file(self):
open_name = 'anchor.X509.signing_request.open' open_name = 'anchor.X509.signing_request.open'
f = io.StringIO(TestX509Csr.csr_data)
with mock.patch(open_name, create=True) as mock_open: with mock.patch(open_name, create=True) as mock_open:
mock_open.return_value = mock.MagicMock(spec=file_class) mock_open.return_value = f
m_file = mock_open.return_value.__enter__.return_value csr = signing_request.X509Csr.from_file("some_path")
m_file.read.return_value = TestX509Csr.csr_data
csr = signing_request.X509Csr()
csr.from_file("some_path")
name = csr.get_subject() name = csr.get_subject()
entries = name.get_entries_by_nid(x509_name.NID_countryName) entries = name.get_entries_by_oid(x509_name.OID_countryName)
self.assertEqual(entries[0].get_value(), "UK") self.assertEqual(entries[0].get_value(), "UK")
def test_bad_data_throws(self): def test_bad_data_throws(self):
bad_data = ( bad_data = (
"some bad data is " u"some bad data is "
"EHRlc3RAYW5jaG9yLnRlc3QwTDANBgkqhkiG9w0BAQEFAAM7ADA4AjEA6m") "EHRlc3RAYW5jaG9yLnRlc3QwTDANBgkqhkiG9w0BAQEFAAM7ADA4AjEA6m")
csr = signing_request.X509Csr() csr = signing_request.X509Csr()
@ -94,49 +98,49 @@ class TestX509Csr(unittest.TestCase):
def test_get_subject_countryName(self): def test_get_subject_countryName(self):
name = self.csr.get_subject() name = self.csr.get_subject()
entries = name.get_entries_by_nid(x509_name.NID_countryName) entries = name.get_entries_by_oid(x509_name.OID_countryName)
self.assertEqual(len(entries), 1) self.assertEqual(len(entries), 1)
self.assertEqual(entries[0].get_name(), "countryName") self.assertEqual(entries[0].get_name(), "countryName")
self.assertEqual(entries[0].get_value(), "UK") self.assertEqual(entries[0].get_value(), "UK")
def test_get_subject_stateOrProvinceName(self): def test_get_subject_stateOrProvinceName(self):
name = self.csr.get_subject() name = self.csr.get_subject()
entries = name.get_entries_by_nid(x509_name.NID_stateOrProvinceName) entries = name.get_entries_by_oid(x509_name.OID_stateOrProvinceName)
self.assertEqual(len(entries), 1) self.assertEqual(len(entries), 1)
self.assertEqual(entries[0].get_name(), "stateOrProvinceName") self.assertEqual(entries[0].get_name(), "stateOrProvinceName")
self.assertEqual(entries[0].get_value(), "Narnia") self.assertEqual(entries[0].get_value(), "Narnia")
def test_get_subject_localityName(self): def test_get_subject_localityName(self):
name = self.csr.get_subject() name = self.csr.get_subject()
entries = name.get_entries_by_nid(x509_name.NID_localityName) entries = name.get_entries_by_oid(x509_name.OID_localityName)
self.assertEqual(len(entries), 1) self.assertEqual(len(entries), 1)
self.assertEqual(entries[0].get_name(), "localityName") self.assertEqual(entries[0].get_name(), "localityName")
self.assertEqual(entries[0].get_value(), "Funkytown") self.assertEqual(entries[0].get_value(), "Funkytown")
def test_get_subject_organizationName(self): def test_get_subject_organizationName(self):
name = self.csr.get_subject() name = self.csr.get_subject()
entries = name.get_entries_by_nid(x509_name.NID_organizationName) entries = name.get_entries_by_oid(x509_name.OID_organizationName)
self.assertEqual(len(entries), 1) self.assertEqual(len(entries), 1)
self.assertEqual(entries[0].get_name(), "organizationName") self.assertEqual(entries[0].get_name(), "organizationName")
self.assertEqual(entries[0].get_value(), "Anchor Testing") self.assertEqual(entries[0].get_value(), "Anchor Testing")
def test_get_subject_organizationUnitName(self): def test_get_subject_organizationUnitName(self):
name = self.csr.get_subject() name = self.csr.get_subject()
entries = name.get_entries_by_nid(x509_name.NID_organizationalUnitName) entries = name.get_entries_by_oid(x509_name.OID_organizationalUnitName)
self.assertEqual(len(entries), 1) self.assertEqual(len(entries), 1)
self.assertEqual(entries[0].get_name(), "organizationalUnitName") self.assertEqual(entries[0].get_name(), "organizationalUnitName")
self.assertEqual(entries[0].get_value(), "testing") self.assertEqual(entries[0].get_value(), "testing")
def test_get_subject_commonName(self): def test_get_subject_commonName(self):
name = self.csr.get_subject() name = self.csr.get_subject()
entries = name.get_entries_by_nid(x509_name.NID_commonName) entries = name.get_entries_by_oid(x509_name.OID_commonName)
self.assertEqual(len(entries), 1) self.assertEqual(len(entries), 1)
self.assertEqual(entries[0].get_name(), "commonName") self.assertEqual(entries[0].get_name(), "commonName")
self.assertEqual(entries[0].get_value(), "anchor.test") self.assertEqual(entries[0].get_value(), "anchor.test")
def test_get_subject_emailAddress(self): def test_get_subject_emailAddress(self):
name = self.csr.get_subject() name = self.csr.get_subject()
entries = name.get_entries_by_nid(x509_name.NID_pkcs9_emailAddress) entries = name.get_entries_by_oid(x509_name.OID_pkcs9_emailAddress)
self.assertEqual(len(entries), 1) self.assertEqual(len(entries), 1)
self.assertEqual(entries[0].get_name(), "emailAddress") self.assertEqual(entries[0].get_name(), "emailAddress")
self.assertEqual(entries[0].get_value(), "test@anchor.test") self.assertEqual(entries[0].get_value(), "test@anchor.test")

View File

@ -24,18 +24,18 @@ class TestX509Name(unittest.TestCase):
def setUp(self): def setUp(self):
super(TestX509Name, self).setUp() super(TestX509Name, self).setUp()
self.name = x509_name.X509Name() self.name = x509_name.X509Name()
self.name.add_name_entry(x509_name.NID_countryName, self.name.add_name_entry(x509_name.OID_countryName,
"UK") # must be 2 chars "UK") # must be 2 chars
self.name.add_name_entry(x509_name.NID_stateOrProvinceName, "test_ST") self.name.add_name_entry(x509_name.OID_stateOrProvinceName, "test_ST")
self.name.add_name_entry(x509_name.NID_localityName, "test_L") self.name.add_name_entry(x509_name.OID_localityName, "test_L")
self.name.add_name_entry(x509_name.NID_organizationName, "test_O") self.name.add_name_entry(x509_name.OID_organizationName, "test_O")
self.name.add_name_entry(x509_name.NID_organizationalUnitName, self.name.add_name_entry(x509_name.OID_organizationalUnitName,
"test_OU") "test_OU")
self.name.add_name_entry(x509_name.NID_commonName, "test_CN") self.name.add_name_entry(x509_name.OID_commonName, "test_CN")
self.name.add_name_entry(x509_name.NID_pkcs9_emailAddress, self.name.add_name_entry(x509_name.OID_pkcs9_emailAddress,
"test_Email") "test_Email")
self.name.add_name_entry(x509_name.NID_surname, "test_SN") self.name.add_name_entry(x509_name.OID_surname, "test_SN")
self.name.add_name_entry(x509_name.NID_givenName, "test_GN") self.name.add_name_entry(x509_name.OID_givenName, "test_GN")
def tearDown(self): def tearDown(self):
pass pass
@ -48,7 +48,7 @@ class TestX509Name(unittest.TestCase):
def test_set_bad_c_throws(self): def test_set_bad_c_throws(self):
self.assertRaises(x509_errors.X509Error, self.assertRaises(x509_errors.X509Error,
self.name.add_name_entry, self.name.add_name_entry,
x509_name.NID_countryName, "BAD_WRONG") x509_name.OID_countryName, "BAD_WRONG")
def test_name_to_string(self): def test_name_to_string(self):
val = str(self.name) val = str(self.name)
@ -57,53 +57,53 @@ class TestX509Name(unittest.TestCase):
"SN=test_SN/GN=test_GN")) "SN=test_SN/GN=test_GN"))
def test_get_countryName(self): def test_get_countryName(self):
entries = self.name.get_entries_by_nid(x509_name.NID_countryName) entries = self.name.get_entries_by_oid(x509_name.OID_countryName)
self.assertEqual(len(entries), 1) self.assertEqual(len(entries), 1)
self.assertEqual(entries[0].get_name(), "countryName") self.assertEqual(entries[0].get_name(), "countryName")
self.assertEqual(entries[0].get_value(), "UK") self.assertEqual(entries[0].get_value(), "UK")
def test_get_stateOrProvinceName(self): def test_get_stateOrProvinceName(self):
entries = self.name.get_entries_by_nid( entries = self.name.get_entries_by_oid(
x509_name.NID_stateOrProvinceName) x509_name.OID_stateOrProvinceName)
self.assertEqual(len(entries), 1) self.assertEqual(len(entries), 1)
self.assertEqual(entries[0].get_name(), "stateOrProvinceName") self.assertEqual(entries[0].get_name(), "stateOrProvinceName")
self.assertEqual(entries[0].get_value(), "test_ST") self.assertEqual(entries[0].get_value(), "test_ST")
def test_get_subject_localityName(self): def test_get_subject_localityName(self):
entries = self.name.get_entries_by_nid(x509_name.NID_localityName) entries = self.name.get_entries_by_oid(x509_name.OID_localityName)
self.assertEqual(len(entries), 1) self.assertEqual(len(entries), 1)
self.assertEqual(entries[0].get_name(), "localityName") self.assertEqual(entries[0].get_name(), "localityName")
self.assertEqual(entries[0].get_value(), "test_L") self.assertEqual(entries[0].get_value(), "test_L")
def test_get_organizationName(self): def test_get_organizationName(self):
entries = self.name.get_entries_by_nid(x509_name.NID_organizationName) entries = self.name.get_entries_by_oid(x509_name.OID_organizationName)
self.assertEqual(len(entries), 1) self.assertEqual(len(entries), 1)
self.assertEqual(entries[0].get_name(), "organizationName") self.assertEqual(entries[0].get_name(), "organizationName")
self.assertEqual(entries[0].get_value(), "test_O") self.assertEqual(entries[0].get_value(), "test_O")
def test_get_organizationUnitName(self): def test_get_organizationUnitName(self):
entries = self.name.get_entries_by_nid( entries = self.name.get_entries_by_oid(
x509_name.NID_organizationalUnitName) x509_name.OID_organizationalUnitName)
self.assertEqual(len(entries), 1) self.assertEqual(len(entries), 1)
self.assertEqual(entries[0].get_name(), "organizationalUnitName") self.assertEqual(entries[0].get_name(), "organizationalUnitName")
self.assertEqual(entries[0].get_value(), "test_OU") self.assertEqual(entries[0].get_value(), "test_OU")
def test_get_commonName(self): def test_get_commonName(self):
entries = self.name.get_entries_by_nid(x509_name.NID_commonName) entries = self.name.get_entries_by_oid(x509_name.OID_commonName)
self.assertEqual(len(entries), 1) self.assertEqual(len(entries), 1)
self.assertEqual(entries[0].get_name(), "commonName") self.assertEqual(entries[0].get_name(), "commonName")
self.assertEqual(entries[0].get_value(), "test_CN") self.assertEqual(entries[0].get_value(), "test_CN")
def test_get_emailAddress(self): def test_get_emailAddress(self):
entries = self.name.get_entries_by_nid( entries = self.name.get_entries_by_oid(
x509_name.NID_pkcs9_emailAddress) x509_name.OID_pkcs9_emailAddress)
self.assertEqual(len(entries), 1) self.assertEqual(len(entries), 1)
self.assertEqual(entries[0].get_name(), "emailAddress") self.assertEqual(entries[0].get_name(), "emailAddress")
self.assertEqual(entries[0].get_value(), "test_Email") self.assertEqual(entries[0].get_value(), "test_Email")
def test_entry_to_string(self): def test_entry_to_string(self):
entries = self.name.get_entries_by_nid( entries = self.name.get_entries_by_oid(
x509_name.NID_pkcs9_emailAddress) x509_name.OID_pkcs9_emailAddress)
self.assertEqual(len(entries), 1) self.assertEqual(len(entries), 1)
self.assertEqual(str(entries[0]), "emailAddress: test_Email") self.assertEqual(str(entries[0]), "emailAddress: test_Email")

View File

@ -29,7 +29,7 @@ class CertificateOpsTests(unittest.TestCase):
def setUp(self): def setUp(self):
# This is a CSR with CN=anchor-test.example.com # This is a CSR with CN=anchor-test.example.com
self.expected_cn = "anchor-test.example.com" self.expected_cn = "anchor-test.example.com"
self.csr = textwrap.dedent(""" self.csr = textwrap.dedent(u"""
-----BEGIN CERTIFICATE REQUEST----- -----BEGIN CERTIFICATE REQUEST-----
MIIEsDCCApgCAQAwazELMAkGA1UEBhMCVVMxEzARBgNVBAgTCkNhbGlmb3JuaWEx MIIEsDCCApgCAQAwazELMAkGA1UEBhMCVVMxEzARBgNVBAgTCkNhbGlmb3JuaWEx
FjAUBgNVBAcTDU1vdW50YWluIFZpZXcxDTALBgNVBAoTBEFjbWUxIDAeBgNVBAMT FjAUBgNVBAcTDU1vdW50YWluIFZpZXcxDTALBgNVBAoTBEFjbWUxIDAeBgNVBAMT
@ -67,16 +67,16 @@ class CertificateOpsTests(unittest.TestCase):
"""Test basic success path for parse_csr.""" """Test basic success path for parse_csr."""
result = certificate_ops.parse_csr(self.csr, 'pem') result = certificate_ops.parse_csr(self.csr, 'pem')
subject = result.get_subject() subject = result.get_subject()
actual_cn = subject.get_entries_by_nid( actual_cn = subject.get_entries_by_oid(
x509_name.NID_commonName)[0].get_value() x509_name.OID_commonName)[0].get_value()
self.assertEqual(actual_cn, self.expected_cn) self.assertEqual(actual_cn, self.expected_cn)
def test_parse_csr_success2(self): def test_parse_csr_success2(self):
"""Test basic success path for parse_csr.""" """Test basic success path for parse_csr."""
result = certificate_ops.parse_csr(self.csr, 'PEM') result = certificate_ops.parse_csr(self.csr, 'PEM')
subject = result.get_subject() subject = result.get_subject()
actual_cn = subject.get_entries_by_nid( actual_cn = subject.get_entries_by_oid(
x509_name.NID_commonName)[0].get_value() x509_name.OID_commonName)[0].get_value()
self.assertEqual(actual_cn, self.expected_cn) self.assertEqual(actual_cn, self.expected_cn)
def test_parse_csr_fail1(self): def test_parse_csr_fail1(self):

View File

@ -58,7 +58,7 @@ class TestFunctional(unittest.TestCase):
} }
""" """
csr_good = textwrap.dedent(""" csr_good = textwrap.dedent(u"""
-----BEGIN CERTIFICATE REQUEST----- -----BEGIN CERTIFICATE REQUEST-----
MIIEDzCCAncCAQAwcjELMAkGA1UEBhMCR0IxEzARBgNVBAgTCkNhbGlmb3JuaWEx MIIEDzCCAncCAQAwcjELMAkGA1UEBhMCR0IxEzARBgNVBAgTCkNhbGlmb3JuaWEx
FjAUBgNVBAcTDVNhbiBGcmFuY3NpY28xDTALBgNVBAoTBE9TU0cxDTALBgNVBAsT FjAUBgNVBAcTDVNhbiBGcmFuY3NpY28xDTALBgNVBAoTBE9TU0cxDTALBgNVBAsT
@ -84,7 +84,7 @@ class TestFunctional(unittest.TestCase):
tR7XqQGqJKca/vRTfJ+zIAxMEeH1N9Lx7YBO6VdVja+yG1E= tR7XqQGqJKca/vRTfJ+zIAxMEeH1N9Lx7YBO6VdVja+yG1E=
-----END CERTIFICATE REQUEST-----""") -----END CERTIFICATE REQUEST-----""")
csr_bad = textwrap.dedent(""" csr_bad = textwrap.dedent(u"""
-----BEGIN CERTIFICATE REQUEST----- -----BEGIN CERTIFICATE REQUEST-----
MIIBWTCCARMCAQAwgZQxCzAJBgNVBAYTAlVLMQ8wDQYDVQQIEwZOYXJuaWExEjAQ MIIBWTCCARMCAQAwgZQxCzAJBgNVBAYTAlVLMQ8wDQYDVQQIEwZOYXJuaWExEjAQ
BgNVBAcTCUZ1bmt5dG93bjEXMBUGA1UEChMOQW5jaG9yIFRlc3RpbmcxEDAOBgNV BgNVBAcTCUZ1bmt5dG93bjEXMBUGA1UEChMOQW5jaG9yIFRlc3RpbmcxEDAOBgNV
@ -149,8 +149,7 @@ class TestFunctional(unittest.TestCase):
resp = self.app.post('/sign', data, expect_errors=False) resp = self.app.post('/sign', data, expect_errors=False)
self.assertEqual(200, resp.status_int) self.assertEqual(200, resp.status_int)
cert = X509_cert.X509Certificate() cert = X509_cert.X509Certificate.from_buffer(resp.text)
cert.from_buffer(resp.text)
# make sure the cert is what we asked for # make sure the cert is what we asked for
self.assertEqual(("/C=GB/ST=California/L=San Francsico/O=OSSG" self.assertEqual(("/C=GB/ST=California/L=San Francsico/O=OSSG"

View File

@ -24,7 +24,7 @@ from anchor.X509 import signing_request
class TestBaseValidators(unittest.TestCase): class TestBaseValidators(unittest.TestCase):
csr_data_with_cn = textwrap.dedent(""" csr_data_with_cn = textwrap.dedent(u"""
-----BEGIN CERTIFICATE REQUEST----- -----BEGIN CERTIFICATE REQUEST-----
MIIDBTCCAe0CAQAwgb8xCzAJBgNVBAYTAlVTMRMwEQYDVQQIEwpDYWxpZm9ybmlh MIIDBTCCAe0CAQAwgb8xCzAJBgNVBAYTAlVTMRMwEQYDVQQIEwpDYWxpZm9ybmlh
MRYwFAYDVQQHEw1TYW4gRnJhbmNpc2NvMSEwHwYDVQQKExhPcGVuU3RhY2sgU2Vj MRYwFAYDVQQHEw1TYW4gRnJhbmNpc2NvMSEwHwYDVQQKExhPcGVuU3RhY2sgU2Vj
@ -51,7 +51,7 @@ class TestBaseValidators(unittest.TestCase):
CN=ossg.test.com/emailAddress=openstack-security@lists.openstack.org CN=ossg.test.com/emailAddress=openstack-security@lists.openstack.org
""" """
csr_data_without_cn = textwrap.dedent(""" csr_data_without_cn = textwrap.dedent(u"""
-----BEGIN CERTIFICATE REQUEST----- -----BEGIN CERTIFICATE REQUEST-----
MIIC7TCCAdUCAQAwgacxCzAJBgNVBAYTAlVTMRMwEQYDVQQIDApDYWxpZm9ybmlh MIIC7TCCAdUCAQAwgacxCzAJBgNVBAYTAlVTMRMwEQYDVQQIDApDYWxpZm9ybmlh
MRYwFAYDVQQHDA1TYW4gRnJhbmNpc2NvMSEwHwYDVQQKDBhPcGVuU3RhY2sgU2Vj MRYwFAYDVQQHDA1TYW4gRnJhbmNpc2NvMSEwHwYDVQQKDBhPcGVuU3RhY2sgU2Vj
@ -79,8 +79,8 @@ class TestBaseValidators(unittest.TestCase):
def setUp(self): def setUp(self):
super(TestBaseValidators, self).setUp() super(TestBaseValidators, self).setUp()
self.csr = signing_request.X509Csr() self.csr = signing_request.X509Csr.from_buffer(
self.csr.from_buffer(TestBaseValidators.csr_data_with_cn) TestBaseValidators.csr_data_with_cn)
def tearDown(self): def tearDown(self):
super(TestBaseValidators, self).tearDown() super(TestBaseValidators, self).tearDown()
@ -89,7 +89,8 @@ class TestBaseValidators(unittest.TestCase):
name = validators.csr_get_cn(self.csr) name = validators.csr_get_cn(self.csr)
self.assertEqual(name, "ossg.test.com") self.assertEqual(name, "ossg.test.com")
self.csr.from_buffer(TestBaseValidators.csr_data_without_cn) self.csr = signing_request.X509Csr.from_buffer(
TestBaseValidators.csr_data_without_cn)
with self.assertRaises(validators.ValidationError): with self.assertRaises(validators.ValidationError):
validators.csr_get_cn(self.csr) validators.csr_get_cn(self.csr)

View File

@ -20,7 +20,9 @@ import mock
import netaddr import netaddr
from anchor import validators from anchor import validators
from anchor.X509 import extension as x509_ext
from anchor.X509 import name as x509_name from anchor.X509 import name as x509_name
from anchor.X509 import signing_request as x509_csr
class TestValidators(unittest.TestCase): class TestValidators(unittest.TestCase):
@ -45,262 +47,174 @@ class TestValidators(unittest.TestCase):
'example.com', [])) 'example.com', []))
def test_common_name_with_two_CN(self): def test_common_name_with_two_CN(self):
ext_mock = mock.MagicMock() csr = x509_csr.X509Csr()
ext_mock.get_name.return_value = "subjectAltName" name = csr.get_subject()
name.add_name_entry(x509_name.OID_commonName, "dummy_value")
csr_config = { name.add_name_entry(x509_name.OID_commonName, "dummy_value")
'get_extensions.return_value': [ext_mock],
'get_subject.return_value.get_entries_by_nid.return_value':
['dummy_value', 'dummy_value'],
}
csr_mock = mock.MagicMock(**csr_config)
with self.assertRaises(validators.ValidationError) as e: with self.assertRaises(validators.ValidationError) as e:
validators.common_name( validators.common_name(
csr=csr_mock, csr=csr,
allowed_domains=[], allowed_domains=[],
allowed_networks=[]) allowed_networks=[])
self.assertEqual("Too many CNs in the request", str(e.exception)) self.assertEqual("Too many CNs in the request", str(e.exception))
def test_common_name_no_CN(self): def test_common_name_no_CN(self):
csr_config = { csr = x509_csr.X509Csr()
'get_subject.return_value.__len__.return_value': 0,
'get_subject.return_value.get_entries_by_nid.return_value':
[]
}
csr_mock = mock.MagicMock(**csr_config)
with self.assertRaises(validators.ValidationError) as e: with self.assertRaises(validators.ValidationError) as e:
validators.common_name( validators.common_name(
csr=csr_mock, csr=csr,
allowed_domains=[], allowed_domains=[],
allowed_networks=[]) allowed_networks=[])
self.assertEqual("Alt subjects have to exist if the main subject" self.assertEqual("Alt subjects have to exist if the main subject"
" doesn't", str(e.exception)) " doesn't", str(e.exception))
def test_common_name_good_CN(self): def test_common_name_good_CN(self):
cn_mock = mock.MagicMock() csr = x509_csr.X509Csr()
cn_mock.get_value.return_value = 'master.test.com' name = csr.get_subject()
name.add_name_entry(x509_name.OID_commonName, "master.test.com")
csr_config = {
'get_subject.return_value.__len__.return_value': 1,
'get_subject.return_value.get_entries_by_nid.return_value':
[cn_mock],
}
csr_mock = mock.MagicMock(**csr_config)
self.assertEqual( self.assertEqual(
None, None,
validators.common_name( validators.common_name(
csr=csr_mock, csr=csr,
allowed_domains=['.test.com'], allowed_domains=['.test.com'],
) )
) )
def test_common_name_bad_CN(self): def test_common_name_bad_CN(self):
name = x509_name.X509Name() csr = x509_csr.X509Csr()
name.add_name_entry(x509_name.NID_commonName, 'test.baddomain.com') name = csr.get_subject()
name.add_name_entry(x509_name.OID_commonName, 'test.baddomain.com')
csr_mock = mock.MagicMock()
csr_mock.get_subject.return_value = name
with self.assertRaises(validators.ValidationError) as e: with self.assertRaises(validators.ValidationError) as e:
validators.common_name( validators.common_name(
csr=csr_mock, csr=csr,
allowed_domains=['.test.com']) allowed_domains=['.test.com'])
self.assertEqual("Domain 'test.baddomain.com' not allowed (does not " self.assertEqual("Domain 'test.baddomain.com' not allowed (does not "
"match known domains)", str(e.exception)) "match known domains)", str(e.exception))
def test_common_name_ip_good(self): def test_common_name_ip_good(self):
name = x509_name.X509Name() csr = x509_csr.X509Csr()
name.add_name_entry(x509_name.NID_commonName, '10.1.1.1') name = csr.get_subject()
name.add_name_entry(x509_name.OID_commonName, '10.1.1.1')
csr_mock = mock.MagicMock()
csr_mock.get_subject.return_value = name
self.assertEqual( self.assertEqual(
None, None,
validators.common_name( validators.common_name(
csr=csr_mock, csr=csr,
allowed_domains=['.test.com'], allowed_domains=['.test.com'],
allowed_networks=['10/8'] allowed_networks=['10/8']
) )
) )
def test_common_name_ip_bad(self): def test_common_name_ip_bad(self):
name = x509_name.X509Name() csr = x509_csr.X509Csr()
name.add_name_entry(x509_name.NID_commonName, '15.1.1.1') name = csr.get_subject()
name.add_name_entry(x509_name.OID_commonName, '15.1.1.1')
csr_mock = mock.MagicMock()
csr_mock.get_subject.return_value = name
with self.assertRaises(validators.ValidationError) as e: with self.assertRaises(validators.ValidationError) as e:
validators.common_name( validators.common_name(
csr=csr_mock, csr=csr,
allowed_domains=['.test.com'], allowed_domains=['.test.com'],
allowed_networks=['10/8']) allowed_networks=['10/8'])
self.assertEqual("Address '15.1.1.1' not allowed (does not " self.assertEqual("Address '15.1.1.1' not allowed (does not "
"match known networks)", str(e.exception)) "match known networks)", str(e.exception))
def test_alternative_names_good_domain(self): def test_alternative_names_good_domain(self):
ext_mock = mock.MagicMock() csr = x509_csr.X509Csr()
ext_mock.get_value.return_value = 'DNS:master.test.com' ext = x509_ext.X509ExtensionSubjectAltName()
ext_mock.get_name.return_value = 'subjectAltName' ext.add_dns_id('master.test.com')
csr.add_extension(ext)
csr_mock = mock.MagicMock()
csr_mock.get_extensions.return_value = [ext_mock]
self.assertEqual( self.assertEqual(
None, None,
validators.alternative_names( validators.alternative_names(
csr=csr_mock, csr=csr,
allowed_domains=['.test.com'], allowed_domains=['.test.com'],
) )
) )
def test_alternative_names_bad_domain(self): def test_alternative_names_bad_domain(self):
ext_mock = mock.MagicMock() csr = x509_csr.X509Csr()
ext_mock.get_value.return_value = 'DNS:test.baddomain.com' ext = x509_ext.X509ExtensionSubjectAltName()
ext_mock.get_name.return_value = 'subjectAltName' ext.add_dns_id('test.baddomain.com')
csr.add_extension(ext)
csr_mock = mock.MagicMock()
csr_mock.get_extensions.return_value = [ext_mock]
with self.assertRaises(validators.ValidationError) as e: with self.assertRaises(validators.ValidationError) as e:
validators.alternative_names( validators.alternative_names(
csr=csr_mock, csr=csr,
allowed_domains=['.test.com']) allowed_domains=['.test.com'])
self.assertEqual("Domain 'test.baddomain.com' not allowed (doesn't " self.assertEqual("Domain 'test.baddomain.com' not allowed (doesn't "
"match known domains)", str(e.exception)) "match known domains)", str(e.exception))
def test_alternative_names_ext(self):
ext_mock = mock.MagicMock()
ext_mock.get_value.return_value = 'BAD,10.1.1.1'
ext_mock.get_name.return_value = 'subjectAltName'
csr_mock = mock.MagicMock()
csr_mock.get_extensions.return_value = [ext_mock]
with self.assertRaises(validators.ValidationError) as e:
validators.alternative_names(
csr=csr_mock,
allowed_domains=['.test.com'])
self.assertEqual("Alt name should have 2 parts, but found: 'BAD'",
str(e.exception))
def test_alternative_names_ip_good(self): def test_alternative_names_ip_good(self):
ext_mock = mock.MagicMock() csr = x509_csr.X509Csr()
ext_mock.get_value.return_value = 'IP Address:10.1.1.1' ext = x509_ext.X509ExtensionSubjectAltName()
ext_mock.get_name.return_value = 'subjectAltName' ext.add_ip(netaddr.IPAddress('10.1.1.1'))
csr.add_extension(ext)
csr_mock = mock.MagicMock()
csr_mock.get_extensions.return_value = [ext_mock]
self.assertEqual( self.assertEqual(
None, None,
validators.alternative_names_ip( validators.alternative_names_ip(
csr=csr_mock, csr=csr,
allowed_domains=['.test.com'], allowed_domains=['.test.com'],
allowed_networks=['10/8'] allowed_networks=['10/8']
) )
) )
def test_alternative_names_ip_bad(self): def test_alternative_names_ip_bad(self):
csr = x509_csr.X509Csr()
ext_mock = mock.MagicMock() ext = x509_ext.X509ExtensionSubjectAltName()
ext_mock.get_value.return_value = 'IP Address:10.1.1.1' ext.add_ip(netaddr.IPAddress('10.1.1.1'))
ext_mock.get_name.return_value = 'subjectAltName' csr.add_extension(ext)
csr_mock = mock.MagicMock()
csr_mock.get_extensions.return_value = [ext_mock]
with self.assertRaises(validators.ValidationError) as e: with self.assertRaises(validators.ValidationError) as e:
validators.alternative_names_ip( validators.alternative_names_ip(
csr=csr_mock, csr=csr,
allowed_domains=['.test.com'], allowed_domains=['.test.com'],
allowed_networks=['99/8']) allowed_networks=['99/8'])
self.assertEqual("Address '10.1.1.1' not allowed (doesn't match known " self.assertEqual("IP '10.1.1.1' not allowed (doesn't match known "
"networks)", str(e.exception)) "networks)", str(e.exception))
def test_alternative_names_ip_bad_domain(self): def test_alternative_names_ip_bad_domain(self):
ext_mock = mock.MagicMock() csr = x509_csr.X509Csr()
ext_mock.get_value.return_value = 'DNS:test.baddomain.com' ext = x509_ext.X509ExtensionSubjectAltName()
ext_mock.get_name.return_value = 'subjectAltName' ext.add_dns_id('test.baddomain.com')
csr.add_extension(ext)
csr_mock = mock.MagicMock()
csr_mock.get_extensions.return_value = [ext_mock]
with self.assertRaises(validators.ValidationError) as e: with self.assertRaises(validators.ValidationError) as e:
validators.alternative_names_ip( validators.alternative_names_ip(
csr=csr_mock, csr=csr,
allowed_domains=['.test.com']) allowed_domains=['.test.com'])
self.assertEqual("Domain 'test.baddomain.com' not allowed (doesn't " self.assertEqual("Domain 'test.baddomain.com' not allowed (doesn't "
"match known domains)", str(e.exception)) "match known domains)", str(e.exception))
def test_alternative_names_ip_ext(self):
ext_mock = mock.MagicMock()
ext_mock.get_value.return_value = 'BAD,10.1.1.1'
ext_mock.get_name.return_value = 'subjectAltName'
csr_mock = mock.MagicMock()
csr_mock.get_extensions.return_value = [ext_mock]
with self.assertRaises(validators.ValidationError) as e:
validators.alternative_names_ip(
csr=csr_mock,
allowed_domains=['.test.com'])
self.assertEqual("Alt name should have 2 parts, but found: 'BAD'",
str(e.exception))
def test_alternative_names_ip_bad_ext(self):
ext_mock = mock.MagicMock()
ext_mock.get_value.return_value = 'BAD:VALUE'
ext_mock.get_name.return_value = 'subjectAltName'
csr_mock = mock.MagicMock()
csr_mock.get_extensions.return_value = [ext_mock]
with self.assertRaises(validators.ValidationError) as e:
validators.alternative_names_ip(
csr=csr_mock,
allowed_domains=['.test.com'],
allowed_networks=['99/8'])
self.assertEqual("Alt name 'VALUE' has unexpected type 'BAD'",
str(e.exception))
def test_server_group_no_prefix1(self): def test_server_group_no_prefix1(self):
cn_mock = mock.MagicMock() csr = x509_csr.X509Csr()
cn_mock.get_value.return_value = 'master.test.com' name = csr.get_subject()
name.add_name_entry(x509_name.OID_commonName, "master.test.com")
csr_config = {
'get_subject.return_value.get_entries_by_nid.return_value':
[cn_mock],
}
csr_mock = mock.MagicMock(**csr_config)
self.assertEqual( self.assertEqual(
None, None,
validators.server_group( validators.server_group(
auth_result=None, auth_result=None,
csr=csr_mock, csr=csr,
group_prefixes={} group_prefixes={}
) )
) )
def test_server_group_no_prefix2(self): def test_server_group_no_prefix2(self):
cn_mock = mock.MagicMock() csr = x509_csr.X509Csr()
cn_mock.get_value.return_value = 'nv_master.test.com' name = csr.get_subject()
name.add_name_entry(x509_name.OID_commonName, "nv_master.test.com")
csr_config = {
'get_subject.return_value.get_entries_by_nid.return_value':
[cn_mock],
}
csr_mock = mock.MagicMock(**csr_config)
self.assertEqual( self.assertEqual(
None, None,
validators.server_group( validators.server_group(
auth_result=None, auth_result=None,
csr=csr_mock, csr=csr,
group_prefixes={} group_prefixes={}
) )
) )
@ -310,20 +224,15 @@ class TestValidators(unittest.TestCase):
auth_result = mock.Mock() auth_result = mock.Mock()
auth_result.groups = ['nova'] auth_result.groups = ['nova']
cn_mock = mock.MagicMock() csr = x509_csr.X509Csr()
cn_mock.get_value.return_value = 'nv_master.test.com' name = csr.get_subject()
name.add_name_entry(x509_name.OID_commonName, "nv_master.test.com")
csr_config = {
'get_subject.return_value.get_entries_by_nid.return_value':
[cn_mock],
}
csr_mock = mock.MagicMock(**csr_config)
self.assertEqual( self.assertEqual(
None, None,
validators.server_group( validators.server_group(
auth_result=auth_result, auth_result=auth_result,
csr=csr_mock, csr=csr,
group_prefixes={'nv': 'nova', 'sw': 'swift'} group_prefixes={'nv': 'nova', 'sw': 'swift'}
) )
) )
@ -332,50 +241,41 @@ class TestValidators(unittest.TestCase):
auth_result = mock.Mock() auth_result = mock.Mock()
auth_result.groups = ['glance'] auth_result.groups = ['glance']
cn_mock = mock.MagicMock() csr = x509_csr.X509Csr()
cn_mock.get_value.return_value = 'nv-master.test.com' name = csr.get_subject()
name.add_name_entry(x509_name.OID_commonName, "nv-master.test.com")
csr_config = {
'get_subject.return_value.get_entries_by_nid.return_value':
[cn_mock],
}
csr_mock = mock.MagicMock(**csr_config)
with self.assertRaises(validators.ValidationError) as e: with self.assertRaises(validators.ValidationError) as e:
validators.server_group( validators.server_group(
auth_result=auth_result, auth_result=auth_result,
csr=csr_mock, csr=csr,
group_prefixes={'nv': 'nova', 'sw': 'swift'}) group_prefixes={'nv': 'nova', 'sw': 'swift'})
self.assertEqual("Server prefix doesn't match user groups", self.assertEqual("Server prefix doesn't match user groups",
str(e.exception)) str(e.exception))
def test_extensions_bad(self): def test_extensions_bad(self):
ext_mock = mock.MagicMock() csr = x509_csr.X509Csr()
ext_mock.get_name.return_value = 'BAD' ext = x509_ext.X509ExtensionKeyUsage()
ext_mock.get_value.return_value = 'BAD' ext.set_usage('keyCertSign', True)
csr.add_extension(ext)
csr_mock = mock.MagicMock()
csr_mock.get_extensions.return_value = [ext_mock]
with self.assertRaises(validators.ValidationError) as e: with self.assertRaises(validators.ValidationError) as e:
validators.extensions( validators.extensions(
csr=csr_mock, csr=csr,
allowed_extensions=['GOOD-1', 'GOOD-2']) allowed_extensions=['basicConstraints', 'nameConstraints'])
self.assertEqual("Extension 'BAD' not allowed", str(e.exception)) self.assertEqual("Extension 'keyUsage' not allowed", str(e.exception))
def test_extensions_good(self): def test_extensions_good(self):
ext_mock = mock.MagicMock() csr = x509_csr.X509Csr()
ext_mock.get_name.return_value = 'GOOD-1' ext = x509_ext.X509ExtensionKeyUsage()
ext_mock.get_value.return_value = 'GOOD-1' ext.set_usage('keyCertSign', True)
csr.add_extension(ext)
csr_mock = mock.MagicMock()
csr_mock.get_extensions.return_value = [ext_mock]
self.assertEqual( self.assertEqual(
None, None,
validators.extensions( validators.extensions(
csr=csr_mock, csr=csr,
allowed_extensions=['GOOD-1', 'GOOD-2'] allowed_extensions=['basicConstraints', 'keyUsage']
) )
) )
@ -384,204 +284,158 @@ class TestValidators(unittest.TestCase):
'Non Repudiation', 'Non Repudiation',
'Key Encipherment'] 'Key Encipherment']
ext_mock = mock.MagicMock() csr = x509_csr.X509Csr()
ext_mock.get_name.return_value = 'keyUsage' ext = x509_ext.X509ExtensionKeyUsage()
ext_mock.get_value.return_value = 'Domination' ext.set_usage('keyCertSign', True)
csr.add_extension(ext)
csr_mock = mock.MagicMock()
csr_mock.get_extensions.return_value = [ext_mock]
with self.assertRaises(validators.ValidationError) as e: with self.assertRaises(validators.ValidationError) as e:
validators.key_usage( validators.key_usage(
csr=csr_mock, csr=csr,
allowed_usage=allowed_usage) allowed_usage=allowed_usage)
self.assertEqual("Found some not allowed key usages: " self.assertEqual("Found some not allowed key usages: "
"Domination", str(e.exception)) "keyCertSign", str(e.exception))
def test_key_usage_good(self): def test_key_usage_good(self):
allowed_usage = ['Digital Signature', allowed_usage = ['Digital Signature',
'Non Repudiation', 'Non Repudiation',
'Key Encipherment'] 'Key Encipherment']
ext_mock = mock.MagicMock() csr = x509_csr.X509Csr()
ext_mock.get_name.return_value = 'keyUsage' ext = x509_ext.X509ExtensionKeyUsage()
ext_mock.get_value.return_value = 'Key Encipherment, Digital Signature' ext.set_usage('keyEncipherment', True)
ext.set_usage('digitalSignature', True)
csr_mock = mock.MagicMock() csr.add_extension(ext)
csr_mock.get_extensions.return_value = [ext_mock]
self.assertEqual( self.assertEqual(
None, None,
validators.key_usage( validators.key_usage(
csr=csr_mock, csr=csr,
allowed_usage=allowed_usage allowed_usage=allowed_usage
) )
) )
def test_ca_status_good1(self): def test_ca_status_good1(self):
ext_mock = mock.MagicMock() csr = x509_csr.X509Csr()
ext_mock.get_name.return_value = 'basicConstraints' ext = x509_ext.X509ExtensionBasicConstraints()
ext_mock.get_value.return_value = 'CA:TRUE' ext.set_ca(True)
csr.add_extension(ext)
csr_mock = mock.MagicMock()
csr_mock.get_extensions.return_value = [ext_mock]
self.assertEqual( self.assertEqual(
None, None,
validators.ca_status( validators.ca_status(
csr=csr_mock, csr=csr,
ca_requested=True ca_requested=True
) )
) )
def test_ca_status_good2(self): def test_ca_status_good2(self):
ext_mock = mock.MagicMock() csr = x509_csr.X509Csr()
ext_mock.get_name.return_value = 'basicConstraints' ext = x509_ext.X509ExtensionBasicConstraints()
ext_mock.get_value.return_value = 'CA:FALSE' ext.set_ca(False)
csr.add_extension(ext)
csr_mock = mock.MagicMock()
csr_mock.get_extensions.return_value = [ext_mock]
self.assertEqual( self.assertEqual(
None, None,
validators.ca_status( validators.ca_status(
csr=csr_mock, csr=csr,
ca_requested=False ca_requested=False
) )
) )
def test_ca_status_bad(self): def test_ca_status_forbidden(self):
ext_mock = mock.MagicMock() csr = x509_csr.X509Csr()
ext_mock.get_name.return_value = 'basicConstraints' ext = x509_ext.X509ExtensionBasicConstraints()
ext_mock.get_value.return_value = 'CA:FALSE' ext.set_ca(True)
csr.add_extension(ext)
csr_mock = mock.MagicMock()
csr_mock.get_extensions.return_value = [ext_mock]
with self.assertRaises(validators.ValidationError) as e: with self.assertRaises(validators.ValidationError) as e:
validators.ca_status( validators.ca_status(
csr=csr_mock, csr=csr,
ca_requested=True) ca_requested=False)
self.assertEqual("Invalid CA status, 'CA:FALSE' requested", self.assertEqual("CA status requested, but not allowed",
str(e.exception)) str(e.exception))
def test_ca_status_bad_format1(self): def test_ca_status_bad(self):
ext_mock = mock.MagicMock() csr = x509_csr.X509Csr()
ext_mock.get_name.return_value = 'basicConstraints' ext = x509_ext.X509ExtensionBasicConstraints()
ext_mock.get_value.return_value = 'CA~FALSE' ext.set_ca(False)
csr.add_extension(ext)
csr_mock = mock.MagicMock()
csr_mock.get_extensions.return_value = [ext_mock]
with self.assertRaises(validators.ValidationError) as e: with self.assertRaises(validators.ValidationError) as e:
validators.ca_status( validators.ca_status(
csr=csr_mock, csr=csr,
ca_requested=False) ca_requested=True)
self.assertEqual("Invalid basic constraints flag", str(e.exception)) self.assertEqual("CA flags required",
str(e.exception))
def test_ca_status_bad_format2(self):
ext_mock = mock.MagicMock()
ext_mock.get_name.return_value = 'basicConstraints'
ext_mock.get_value.return_value = 'CA:FALSE:DERP'
csr_mock = mock.MagicMock()
csr_mock.get_extensions.return_value = [ext_mock]
with self.assertRaises(validators.ValidationError) as e:
validators.ca_status(
csr=csr_mock,
ca_requested=False)
self.assertEqual("Invalid basic constraints flag", str(e.exception))
def test_ca_status_pathlen(self): def test_ca_status_pathlen(self):
ext_mock = mock.MagicMock() csr = x509_csr.X509Csr()
ext_mock.get_name.return_value = 'basicConstraints' ext = x509_ext.X509ExtensionBasicConstraints()
ext_mock.get_value.return_value = 'pathlen:somthing' ext.set_path_len_constraint(1)
csr.add_extension(ext)
csr_mock = mock.MagicMock()
csr_mock.get_extensions.return_value = [ext_mock]
self.assertEqual( self.assertEqual(
None, None,
validators.ca_status( validators.ca_status(
csr=csr_mock, csr=csr,
ca_requested=False ca_requested=False
) )
) )
def test_ca_status_bad_value(self):
ext_mock = mock.MagicMock()
ext_mock.get_name.return_value = 'basicConstraints'
ext_mock.get_value.return_value = 'BAD:VALUE'
csr_mock = mock.MagicMock()
csr_mock.get_extensions.return_value = [ext_mock]
with self.assertRaises(validators.ValidationError) as e:
validators.ca_status(
csr=csr_mock,
ca_requested=False)
self.assertEqual("Invalid basic constraints option", str(e.exception))
def test_ca_status_key_usage_bad1(self): def test_ca_status_key_usage_bad1(self):
ext_mock = mock.MagicMock() csr = x509_csr.X509Csr()
ext_mock.get_name.return_value = 'keyUsage' ext = x509_ext.X509ExtensionKeyUsage()
ext_mock.get_value.return_value = 'Certificate Sign' ext.set_usage('keyCertSign', True)
csr.add_extension(ext)
csr_mock = mock.MagicMock()
csr_mock.get_extensions.return_value = [ext_mock]
with self.assertRaises(validators.ValidationError) as e: with self.assertRaises(validators.ValidationError) as e:
validators.ca_status( validators.ca_status(
csr=csr_mock, csr=csr,
ca_requested=False) ca_requested=False)
self.assertEqual("Key usage doesn't match requested CA status " self.assertEqual("Key usage doesn't match requested CA status "
"(keyCertSign/cRLSign: True/False)", str(e.exception)) "(keyCertSign/cRLSign: True/False)", str(e.exception))
def test_ca_status_key_usage_good1(self): def test_ca_status_key_usage_good1(self):
ext_mock = mock.MagicMock() csr = x509_csr.X509Csr()
ext_mock.get_name.return_value = 'keyUsage' ext = x509_ext.X509ExtensionKeyUsage()
ext_mock.get_value.return_value = 'Certificate Sign' ext.set_usage('keyCertSign', True)
csr.add_extension(ext)
csr_mock = mock.MagicMock() self.assertEqual(
csr_mock.get_extensions.return_value = [ext_mock] None,
with self.assertRaises(validators.ValidationError) as e:
validators.ca_status( validators.ca_status(
csr=csr_mock, csr=csr,
ca_requested=True) ca_requested=True
self.assertEqual("Key usage doesn't match requested CA status " )
"(keyCertSign/cRLSign: True/False)", str(e.exception)) )
def test_ca_status_key_usage_bad2(self): def test_ca_status_key_usage_bad2(self):
ext_mock = mock.MagicMock() csr = x509_csr.X509Csr()
ext_mock.get_name.return_value = 'keyUsage' ext = x509_ext.X509ExtensionKeyUsage()
ext_mock.get_value.return_value = 'CRL Sign' ext.set_usage('cRLSign', True)
csr.add_extension(ext)
csr_mock = mock.MagicMock()
csr_mock.get_extensions.return_value = [ext_mock]
with self.assertRaises(validators.ValidationError) as e: with self.assertRaises(validators.ValidationError) as e:
validators.ca_status( validators.ca_status(
csr=csr_mock, csr=csr,
ca_requested=False) ca_requested=False)
self.assertEqual("Key usage doesn't match requested CA status " self.assertEqual("Key usage doesn't match requested CA status "
"(keyCertSign/cRLSign: False/True)", str(e.exception)) "(keyCertSign/cRLSign: False/True)", str(e.exception))
def test_ca_status_key_usage_good2(self): def test_ca_status_key_usage_good2(self):
ext_mock = mock.MagicMock() csr = x509_csr.X509Csr()
ext_mock.get_name.return_value = 'keyUsage' ext = x509_ext.X509ExtensionKeyUsage()
ext_mock.get_value.return_value = 'CRL Sign' ext.set_usage('cRLSign', True)
csr.add_extension(ext)
csr_mock = mock.MagicMock() self.assertEqual(
csr_mock.get_extensions.return_value = [ext_mock] None,
with self.assertRaises(validators.ValidationError) as e:
validators.ca_status( validators.ca_status(
csr=csr_mock, csr=csr,
ca_requested=True) ca_requested=True
self.assertEqual("Key usage doesn't match requested CA status " )
"(keyCertSign/cRLSign: False/True)", str(e.exception)) )
def test_source_cidrs_good(self): def test_source_cidrs_good(self):
request = mock.Mock(client_addr='127.0.0.1') request = mock.Mock(client_addr='127.0.0.1')
@ -612,99 +466,65 @@ class TestValidators(unittest.TestCase):
str(e.exception)) str(e.exception))
def test_blacklist_names_good(self): def test_blacklist_names_good(self):
ext_mock = mock.MagicMock() csr = x509_csr.X509Csr()
ext_mock.get_value.return_value = 'DNS:blah.good' ext = x509_ext.X509ExtensionSubjectAltName()
ext_mock.get_name.return_value = 'subjectAltName' ext.add_dns_id('blah.good')
csr.add_extension(ext)
csr_mock = mock.MagicMock()
csr_mock.get_extensions.return_value = [ext_mock]
self.assertEqual( self.assertEqual(
None, None,
validators.blacklist_names( validators.blacklist_names(
csr=csr_mock, csr=csr,
domains=['.bad'], domains=['.bad'],
) )
) )
def test_blacklist_names_bad(self): def test_blacklist_names_bad(self):
ext_mock = mock.MagicMock() csr = x509_csr.X509Csr()
ext_mock.get_value.return_value = 'DNS:blah.bad' ext = x509_ext.X509ExtensionSubjectAltName()
ext_mock.get_name.return_value = 'subjectAltName' ext.add_dns_id('blah.bad')
csr.add_extension(ext)
csr_mock = mock.MagicMock()
csr_mock.get_extensions.return_value = [ext_mock]
with self.assertRaises(validators.ValidationError): with self.assertRaises(validators.ValidationError):
validators.blacklist_names( validators.blacklist_names(
csr=csr_mock, csr=csr,
domains=['.bad'], domains=['.bad'],
) )
def test_blacklist_names_bad_cn(self): def test_blacklist_names_bad_cn(self):
cn_mock = mock.MagicMock() csr = x509_csr.X509Csr()
cn_mock.get_value.return_value = 'blah.bad' name = csr.get_subject()
name.add_name_entry(x509_name.OID_commonName, "blah.bad")
csr_config = {
'get_subject.return_value.get_entries_by_nid.return_value':
[cn_mock],
}
csr_mock = mock.MagicMock(**csr_config)
with self.assertRaises(validators.ValidationError): with self.assertRaises(validators.ValidationError):
validators.blacklist_names( validators.blacklist_names(
csr=csr_mock, csr=csr,
domains=['.bad'], domains=['.bad'],
) )
def test_blacklist_names_mix(self): def test_blacklist_names_mix(self):
ext1_mock = mock.MagicMock() csr = x509_csr.X509Csr()
ext1_mock.get_value.return_value = 'DNS:blah.good' ext = x509_ext.X509ExtensionSubjectAltName()
ext1_mock.get_name.return_value = 'subjectAltName' ext.add_dns_id('blah.bad')
ext.add_dns_id('blah.good')
ext2_mock = mock.MagicMock() csr.add_extension(ext)
ext2_mock.get_value.return_value = 'DNS:blah.bad'
ext2_mock.get_name.return_value = 'subjectAltName'
csr_mock = mock.MagicMock()
csr_mock.get_extensions.return_value = [ext1_mock, ext2_mock]
with self.assertRaises(validators.ValidationError): with self.assertRaises(validators.ValidationError):
validators.blacklist_names( validators.blacklist_names(
csr=csr_mock, csr=csr,
domains=['.bad'], domains=['.bad'],
) )
def test_blacklist_names_ignore_unknown(self):
# only validate the DNS type - other types may look like domains
# by accident
ext_mock = mock.MagicMock()
ext_mock.get_value.return_value = 'RANDOM_TYPE:random.bad'
ext_mock.get_name.return_value = 'subjectAltName'
csr_mock = mock.MagicMock()
csr_mock.get_extensions.return_value = [ext_mock]
self.assertEqual(
None,
validators.blacklist_names(
csr=csr_mock,
domains=['.bad'],
)
)
def test_blacklist_names_empty_list(self): def test_blacklist_names_empty_list(self):
# empty blacklist should pass everything through # empty blacklist should pass everything through
ext_mock = mock.MagicMock() csr = x509_csr.X509Csr()
ext_mock.get_value.return_value = 'DNS:some.name' ext = x509_ext.X509ExtensionSubjectAltName()
ext_mock.get_name.return_value = 'subjectAltName' ext.add_dns_id('blah.good')
csr.add_extension(ext)
csr_mock = mock.MagicMock()
csr_mock.get_extensions.return_value = [ext_mock]
self.assertEqual( self.assertEqual(
None, None,
validators.blacklist_names( validators.blacklist_names(
csr=csr_mock, csr=csr,
) )
) )