Adds WinRM Listener plugin

This commit is contained in:
Alessandro Pilotti 2013-12-16 02:47:06 +02:00
parent 02a2a77e38
commit 44af45a1f8
7 changed files with 819 additions and 13 deletions

View File

@ -18,11 +18,11 @@ import base64
import os
import subprocess
PROTOCOL_TCP = "TCP"
PROTOCOL_UDP = "UDP"
class BaseOSUtils(object):
PROTOCOL_TCP = "TCP"
PROTOCOL_UDP = "UDP"
def reboot(self):
raise NotImplementedError()

View File

@ -158,6 +158,19 @@ class WindowsUtils(base.BaseOSUtils):
DRIVE_CDROM = 5
SERVICE_STATUS_STOPPED = "Stopped"
SERVICE_STATUS_START_PENDING = "Start Pending"
SERVICE_STATUS_STOP_PENDING = "Stop Pending"
SERVICE_STATUS_RUNNING = "Running"
SERVICE_STATUS_CONTINUE_PENDING = "Continue Pending"
SERVICE_STATUS_PAUSE_PENDING = "Pause Pending"
SERVICE_STATUS_PAUSED = "Paused"
SERVICE_STATUS_UNKNOWN = "Unknown"
SERVICE_START_MODE_AUTOMATIC = "Automatic"
SERVICE_START_MODE_MANUAL = "Manual"
SERVICE_START_MODE_DISABLED = "Disabled"
ComputerNamePhysicalDnsHostname = 5
_config_key = 'SOFTWARE\\Cloudbase Solutions\\Cloudbase-Init\\'
@ -417,12 +430,42 @@ class WindowsUtils(base.BaseOSUtils):
else:
raise ex
def _stop_service(self, service_name):
LOG.debug('Stopping service %s' % service_name)
def _get_service(self, service_name):
conn = wmi.WMI(moniker='//./root/cimv2')
service = conn.Win32_Service(Name=service_name)[0]
service_list = conn.Win32_Service(Name=service_name)
if len(service_list):
return service_list[0]
def check_service_exists(self, service_name):
return self._get_service(service_name) is not None
def get_service_status(self, service_name):
service = self._get_service(service_name)
return service.State
def get_service_start_mode(self, service_name):
service = self._get_service(service_name)
return service.StartMode
def set_service_start_mode(self, service_name, start_mode):
#TODO(alexpilotti): Handle the "Delayed Start" case
service = self._get_service(service_name)
(ret_val,) = service.ChangeStartMode(start_mode)
if ret_val != 0:
raise Exception('Setting service %(service_name)s start mode '
'failed with return value: %(ret_val)d' % locals())
def start_service(self, service_name):
LOG.debug('Starting service %s' % service_name)
service = self._get_service(service_name)
(ret_val,) = service.StartService()
if ret_val != 0:
raise Exception('Starting service %(service_name)s failed with '
'return value: %(ret_val)d' % locals())
def stop_service(self, service_name):
LOG.debug('Stopping service %s' % service_name)
service = self._get_service(service_name)
(ret_val,) = service.StopService()
if ret_val != 0:
raise Exception('Stopping service %(service_name)s failed with '
@ -432,7 +475,7 @@ class WindowsUtils(base.BaseOSUtils):
# Wait for the service to start. Polling the service "Started" property
# is not enough
time.sleep(3)
self._stop_service(self._service_name)
self.stop_service(self._service_name)
def get_default_gateway(self):
default_routes = [r for r in self._get_ipv4_routing_table()
@ -475,9 +518,10 @@ class WindowsUtils(base.BaseOSUtils):
'Error: %s' % err)
forward_table = p_forward_table.contents
table = ctypes.cast(ctypes.addressof(forward_table.table),
ctypes.POINTER(Win32_MIB_IPFORWARDROW *
forward_table.dwNumEntries)).contents
table = ctypes.cast(
ctypes.addressof(forward_table.table),
ctypes.POINTER(Win32_MIB_IPFORWARDROW *
forward_table.dwNumEntries)).contents
i = 0
while i < forward_table.dwNumEntries:
@ -578,9 +622,9 @@ class WindowsUtils(base.BaseOSUtils):
self.DRIVE_CDROM]
def _get_fw_protocol(self, protocol):
if protocol == base.PROTOCOL_TCP:
if protocol == self.PROTOCOL_TCP:
fw_protocol = self._FW_IP_PROTOCOL_TCP
elif protocol == base.PROTOCOL_UDP:
elif protocol == self.PROTOCOL_UDP:
fw_protocol = self._FW_IP_PROTOCOL_UDP
else:
raise NotImplementedError("Unsupported protocol")

View File

@ -29,6 +29,8 @@ opts = [
'cloudbaseinit.plugins.windows.extendvolumes.ExtendVolumesPlugin',
'cloudbaseinit.plugins.windows.userdata.UserDataPlugin',
'cloudbaseinit.plugins.windows.setuserpassword.SetUserPasswordPlugin',
'cloudbaseinit.plugins.windows.winrmlistener.'
'ConfigWinRMListenerPlugin',
],
help='List of enabled plugin classes, '
'to executed in the provided order'),

View File

@ -0,0 +1,229 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2013 Cloudbase Solutions Srl
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import ctypes
from ctypes import windll
from ctypes import wintypes
class CryptoAPIException(Exception):
def __init__(self):
message = self._get_windows_error()
super(CryptoAPIException, self).__init__(message)
def _get_windows_error(self):
err_code = GetLastError()
return "CryptoAPI error: 0x%0x" % err_code
class SYSTEMTIME(ctypes.Structure):
_fields_ = [
('wYear', wintypes.WORD),
('wMonth', wintypes.WORD),
('wDayOfWeek', wintypes.WORD),
('wDay', wintypes.WORD),
('wHour', wintypes.WORD),
('wMinute', wintypes.WORD),
('wSecond', wintypes.WORD),
('wMilliseconds', wintypes.WORD),
]
class CERT_CONTEXT(ctypes.Structure):
_fields_ = [
('dwCertEncodingType', wintypes.DWORD),
('pbCertEncoded', ctypes.POINTER(wintypes.BYTE)),
('cbCertEncoded', wintypes.DWORD),
('pCertInfo', ctypes.c_void_p),
('hCertStore', wintypes.HANDLE),
]
class CRYPTOAPI_BLOB(ctypes.Structure):
_fields_ = [
('cbData', wintypes.DWORD),
('pbData', ctypes.POINTER(wintypes.BYTE)),
]
class CRYPT_ALGORITHM_IDENTIFIER(ctypes.Structure):
_fields_ = [
('pszObjId', wintypes.LPSTR),
('Parameters', CRYPTOAPI_BLOB),
]
class CRYPT_KEY_PROV_PARAM(ctypes.Structure):
_fields_ = [
('dwParam', wintypes.DWORD),
('pbData', ctypes.POINTER(wintypes.BYTE)),
('cbData', wintypes.DWORD),
('dwFlags', wintypes.DWORD),
]
class CRYPT_KEY_PROV_INFO(ctypes.Structure):
_fields_ = [
('pwszContainerName', wintypes.LPWSTR),
('pwszProvName', wintypes.LPWSTR),
('dwProvType', wintypes.DWORD),
('dwFlags', wintypes.DWORD),
('cProvParam', wintypes.DWORD),
('cProvParam', ctypes.POINTER(CRYPT_KEY_PROV_PARAM)),
('dwKeySpec', wintypes.DWORD),
]
AT_SIGNATURE = 2
CERT_NAME_UPN_TYPE = 8
CERT_SHA1_HASH_PROP_ID = 3
CERT_STORE_ADD_REPLACE_EXISTING = 3
CERT_STORE_PROV_SYSTEM = wintypes.LPSTR(10)
CERT_SYSTEM_STORE_CURRENT_USER = 65536
CERT_SYSTEM_STORE_LOCAL_MACHINE = 131072
CERT_X500_NAME_STR = 3
CRYPT_MACHINE_KEYSET = 32
CRYPT_NEWKEYSET = 8
CRYPT_STRING_BASE64 = 1
PKCS_7_ASN_ENCODING = 65536
PROV_RSA_FULL = 1
X509_ASN_ENCODING = 1
szOID_PKIX_KP_SERVER_AUTH = "1.3.6.1.5.5.7.3.1"
szOID_RSA_SHA1RSA = "1.2.840.113549.1.1.5"
advapi32 = windll.advapi32
crypt32 = windll.crypt32
kernel32 = windll.kernel32
advapi32.CryptAcquireContextW.restype = wintypes.BOOL
advapi32.CryptAcquireContextW.argtypes = [wintypes.HANDLE, wintypes.LPCWSTR,
wintypes.LPCWSTR, wintypes.DWORD,
wintypes.DWORD]
CryptAcquireContext = advapi32.CryptAcquireContextW
advapi32.CryptReleaseContext.restype = wintypes.BOOL
advapi32.CryptReleaseContext.argtypes = [wintypes.HANDLE, wintypes.DWORD]
CryptReleaseContext = advapi32.CryptReleaseContext
advapi32.CryptGenKey.restype = wintypes.BOOL
advapi32.CryptGenKey.argtypes = [wintypes.HANDLE,
wintypes.DWORD,
wintypes.DWORD,
ctypes.POINTER(wintypes.HANDLE)]
CryptGenKey = advapi32.CryptGenKey
advapi32.CryptDestroyKey.restype = wintypes.BOOL
advapi32.CryptDestroyKey.argtypes = [wintypes.HANDLE]
CryptDestroyKey = advapi32.CryptDestroyKey
crypt32.CertStrToNameW.restype = wintypes.BOOL
crypt32.CertStrToNameW.argtypes = [wintypes.DWORD, wintypes.LPCWSTR,
wintypes.DWORD, ctypes.c_void_p,
ctypes.POINTER(wintypes.BYTE),
ctypes.POINTER(wintypes.DWORD),
ctypes.POINTER(wintypes.LPCWSTR)]
CertStrToName = crypt32.CertStrToNameW
# TODO(alexpilotti): this is not a CryptoAPI funtion, putting it in a separate
# module would be more correct
kernel32.GetSystemTime.restype = None
kernel32.GetSystemTime.argtypes = [ctypes.POINTER(SYSTEMTIME)]
GetSystemTime = kernel32.GetSystemTime
# TODO(alexpilotti): this is not a CryptoAPI funtion, putting it in a separate
# module would be more correct
kernel32.GetLastError.restype = wintypes.DWORD
kernel32.GetLastError.argtypes = []
GetLastError = kernel32.GetLastError
crypt32.CertCreateSelfSignCertificate.restype = ctypes.POINTER(CERT_CONTEXT)
crypt32.CertCreateSelfSignCertificate.argtypes = [
wintypes.HANDLE,
ctypes.POINTER(CRYPTOAPI_BLOB),
wintypes.DWORD,
ctypes.POINTER(CRYPT_KEY_PROV_INFO),
ctypes.POINTER(CRYPT_ALGORITHM_IDENTIFIER),
ctypes.POINTER(SYSTEMTIME),
ctypes.POINTER(SYSTEMTIME),
# PCERT_EXTENSIONS
ctypes.c_void_p]
CertCreateSelfSignCertificate = crypt32.CertCreateSelfSignCertificate
crypt32.CertAddEnhancedKeyUsageIdentifier.restype = wintypes.BOOL
crypt32.CertAddEnhancedKeyUsageIdentifier.argtypes = [
ctypes.POINTER(CERT_CONTEXT),
wintypes.LPCSTR]
CertAddEnhancedKeyUsageIdentifier = crypt32.CertAddEnhancedKeyUsageIdentifier
crypt32.CertOpenStore.restype = wintypes.HANDLE
crypt32.CertOpenStore.argtypes = [wintypes.LPCSTR, wintypes.DWORD,
wintypes.HANDLE, wintypes.DWORD,
ctypes.c_void_p]
CertOpenStore = crypt32.CertOpenStore
crypt32.CertAddCertificateContextToStore.restype = wintypes.BOOL
crypt32.CertAddCertificateContextToStore.argtypes = [
wintypes.HANDLE,
ctypes.POINTER(CERT_CONTEXT),
wintypes.DWORD,
ctypes.POINTER(CERT_CONTEXT)]
CertAddCertificateContextToStore = crypt32.CertAddCertificateContextToStore
crypt32.CryptStringToBinaryA.restype = wintypes.BOOL
crypt32.CryptStringToBinaryA.argtypes = [wintypes.LPCSTR,
wintypes.DWORD,
wintypes.DWORD,
ctypes.POINTER(wintypes.BYTE),
ctypes.POINTER(wintypes.DWORD),
ctypes.POINTER(wintypes.DWORD),
ctypes.POINTER(wintypes.DWORD)]
CryptStringToBinaryA = crypt32.CryptStringToBinaryA
crypt32.CertAddEncodedCertificateToStore.restype = wintypes.BOOL
crypt32.CertAddEncodedCertificateToStore.argtypes = [
wintypes.HANDLE,
wintypes.DWORD,
ctypes.POINTER(wintypes.BYTE),
wintypes.DWORD,
wintypes.DWORD,
ctypes.POINTER(ctypes.POINTER(CERT_CONTEXT))]
CertAddEncodedCertificateToStore = crypt32.CertAddEncodedCertificateToStore
crypt32.CertGetNameStringW.restype = wintypes.DWORD
crypt32.CertGetNameStringW.argtypes = [ctypes.POINTER(CERT_CONTEXT),
wintypes.DWORD,
wintypes.DWORD,
ctypes.c_void_p,
wintypes.LPWSTR,
wintypes.DWORD]
CertGetNameString = crypt32.CertGetNameStringW
crypt32.CertFreeCertificateContext.restype = wintypes.BOOL
crypt32.CertFreeCertificateContext.argtypes = [ctypes.POINTER(CERT_CONTEXT)]
CertFreeCertificateContext = crypt32.CertFreeCertificateContext
crypt32.CertCloseStore.restype = wintypes.BOOL
crypt32.CertCloseStore.argtypes = [wintypes.HANDLE, wintypes.DWORD]
CertCloseStore = crypt32.CertCloseStore
crypt32.CertGetCertificateContextProperty.restype = wintypes.BOOL
crypt32.CertGetCertificateContextProperty.argtypes = [
ctypes.POINTER(CERT_CONTEXT),
wintypes.DWORD,
ctypes.c_void_p,
ctypes.POINTER(wintypes.DWORD)]
CertGetCertificateContextProperty = crypt32.CertGetCertificateContextProperty

View File

@ -0,0 +1,155 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2012 Cloudbase Solutions Srl
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import pywintypes
import re
from win32com import client
from xml.etree import ElementTree
CBT_HARDENING_LEVEL_NONE = "none"
CBT_HARDENING_LEVEL_RELAXED = "relaxed"
CBT_HARDENING_LEVEL_STRICT = "strict"
LISTENER_PROTOCOL_HTTP = "HTTP"
LISTENER_PROTOCOL_HTTPS = "HTTPS"
class WinRMConfig(object):
_SERVICE_AUTH_URI = 'winrm/Config/Service/Auth'
_SERVICE_LISTENER_URI = 'winrm/Config/Listener?Address=*+Transport=%s'
def _get_wsman_session(self):
wsman = client.Dispatch('WSMan.Automation')
return wsman.CreateSession()
def _get_node_tag(self, tag):
return re.match("^{.*}(.*)$", tag).groups(1)[0]
def _parse_listener_xml(self, data_xml):
listening_on = []
data = {"ListeningOn": listening_on}
ns = {'cfg':
'http://schemas.microsoft.com/wbem/wsman/1/config/listener'}
tree = ElementTree.fromstring(data_xml)
for node in tree:
tag = self._get_node_tag(node.tag)
if tag == "ListeningOn":
listening_on.append(node.text)
elif tag == "Enabled":
if node.text == "true":
value = True
else:
value = False
data[tag] = value
elif tag == "Port":
data[tag] = int(node.text)
else:
data[tag] = node.text
return data
def get_listener(self, protocol=LISTENER_PROTOCOL_HTTPS):
session = self._get_wsman_session()
resourceUri = self._SERVICE_LISTENER_URI % protocol
try:
data_xml = session.Get(resourceUri)
except pywintypes.com_error, ex:
if len(ex.excepinfo) > 5 and ex.excepinfo[5] == -2144108544:
return None
else:
raise
return self._parse_listener_xml(data_xml)
def delete_listener(self, protocol=LISTENER_PROTOCOL_HTTPS):
session = self._get_wsman_session()
resourceUri = self._SERVICE_LISTENER_URI % protocol
session.Delete(resourceUri)
def create_listener(self, protocol=LISTENER_PROTOCOL_HTTPS, enabled=True,
cert_thumbprint=None):
session = self._get_wsman_session()
resource_uri = self._SERVICE_LISTENER_URI % protocol
if enabled:
enabled_str = "true"
else:
enabled_str = "false"
session.Create(
resource_uri,
'<p:Listener xmlns:p="http://schemas.microsoft.com/'
'wbem/wsman/1/config/listener.xsd">'
'<p:Enabled>%(enabled_str)s</p:Enabled>'
'<p:CertificateThumbPrint>%(cert_thumbprint)s'
'</p:CertificateThumbPrint>'
'<p:URLPrefix>wsman</p:URLPrefix>'
'</p:Listener>' % {"enabled_str": enabled_str,
"cert_thumbprint": cert_thumbprint})
def get_auth_config(self):
data = {}
session = self._get_wsman_session()
data_xml = session.Get(self._SERVICE_AUTH_URI)
tree = ElementTree.fromstring(data_xml)
for node in tree:
tag = self._get_node_tag(node.tag)
value_str = node.text.lower()
if value_str == "true":
value = True
elif value_str == "false":
value = False
else:
value = value_str
data[tag] = value
return data
def set_auth_config(self, basic=None, kerberos=None, negotiate=None,
certificate=None, credSSP=None,
cbt_hardening_level=None):
tag_map = {'Basic': basic,
'Kerberos': kerberos,
'Negotiate': negotiate,
'Certificate': certificate,
'CredSSP': credSSP,
'CbtHardeningLevel': cbt_hardening_level}
session = self._get_wsman_session()
data_xml = session.Get(self._SERVICE_AUTH_URI)
ns = {'cfg':
'http://schemas.microsoft.com/wbem/wsman/1/config/service/auth'}
tree = ElementTree.fromstring(data_xml)
for (tag, value) in tag_map.items():
if value is not None:
if value:
new_value = "true"
else:
new_value = "false"
node = tree.find('.//cfg:%s' % tag, namespaces=ns)
if node.text.lower() != new_value:
node.text = new_value
data_xml = ElementTree.tostring(tree)
session.Put(self._SERVICE_AUTH_URI, data_xml)

View File

@ -0,0 +1,91 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2013 Cloudbase Solutions Srl
#
# Licensed under the Apache License, Version 2.0 (the 'License'); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from cloudbaseinit.openstack.common import cfg
from cloudbaseinit.openstack.common import log as logging
from cloudbaseinit.osutils import factory as osutils_factory
from cloudbaseinit.plugins import base
from cloudbaseinit.plugins.windows import x509
from cloudbaseinit.plugins.windows import winrmconfig
LOG = logging.getLogger(__name__)
opts = [
cfg.BoolOpt('winrm_enable_basic_auth', default=True,
help='Enables basic authentication for the WinRM '
'HTTPS listener'),
]
CONF = cfg.CONF
CONF.register_opts(opts)
LOG = logging.getLogger(__name__)
class ConfigWinRMListenerPlugin(base.BasePlugin):
_cert_subject = "CN=Cloudbase-Init"
_winrm_service_name = "WinRM"
def _check_winrm_service(self, osutils):
if not osutils.check_service_exists(self._winrm_service_name):
LOG.warn("Cannot configure the WinRM listener as the service "
"is not available")
return False
start_mode = osutils.get_service_start_mode(self._winrm_service_name)
if start_mode in [osutils.SERVICE_START_MODE_MANUAL,
osutils.SERVICE_START_MODE_DISABLED]:
# TODO(alexpilotti) Set to "Delayed Start"
osutils.set_service_start_mode(
self._winrm_service_name,
osutils.SERVICE_START_MODE_AUTOMATIC)
service_status = osutils.get_service_status(self._winrm_service_name)
if service_status == osutils.SERVICE_STATUS_STOPPED:
osutils.start_service(self._winrm_service_name)
return True
def execute(self, service):
osutils = osutils_factory.OSUtilsFactory().get_os_utils()
if not self._check_winrm_service(osutils):
return (base.PLUGIN_EXECUTE_ON_NEXT_BOOT, False)
winrm_config = winrmconfig.WinRMConfig()
winrm_config.set_auth_config(basic=CONF.winrm_enable_basic_auth)
cert_manager = x509.CryptoAPICertManager()
cert_thumbprint = cert_manager.create_self_signed_cert(
self._cert_subject)
protocol = winrmconfig.LISTENER_PROTOCOL_HTTPS
if winrm_config.get_listener(protocol=protocol):
winrm_config.delete_listener(protocol=protocol)
winrm_config.create_listener(
cert_thumbprint=cert_thumbprint,
protocol=protocol)
listener_config = winrm_config.get_listener(protocol=protocol)
listener_port = listener_config.get("Port")
rule_name = "WinRM %s" % protocol
osutils.firewall_create_rule(rule_name, listener_port,
osutils.PROTOCOL_TCP)
return (base.PLUGIN_EXECUTION_DONE, False)

View File

@ -0,0 +1,285 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2013 Cloudbase Solutions Srl
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import copy
import ctypes
import uuid
from ctypes import wintypes
from cloudbaseinit.plugins.windows import cryptoapi
malloc = ctypes.cdll.msvcrt.malloc
malloc.restype = ctypes.c_void_p
malloc.argtypes = [ctypes.c_size_t]
free = ctypes.cdll.msvcrt.free
free.restype = None
free.argtypes = [ctypes.c_void_p]
class CryptoAPICertManager(object):
def _get_cert_thumprint(self, cert_context_p):
thumbprint = None
try:
thumprint_len = wintypes.DWORD()
if not cryptoapi.CertGetCertificateContextProperty(
cert_context_p,
cryptoapi.CERT_SHA1_HASH_PROP_ID,
None, ctypes.byref(thumprint_len)):
raise cryptoapi.CryptoAPIException()
thumbprint = malloc(thumprint_len)
if not cryptoapi.CertGetCertificateContextProperty(
cert_context_p,
cryptoapi.CERT_SHA1_HASH_PROP_ID,
thumbprint, ctypes.byref(thumprint_len)):
raise cryptoapi.CryptoAPIException()
thumbprint_ar = ctypes.cast(
thumbprint,
ctypes.POINTER(ctypes.c_ubyte *
thumprint_len.value)).contents
thumbprint_str = ""
for b in thumbprint_ar:
thumbprint_str += "%02x" % b
return thumbprint_str
finally:
if thumbprint:
free(thumbprint)
def _generate_key(self, container_name, machine_keyset):
crypt_prov_handle = wintypes.HANDLE()
key_handle = wintypes.HANDLE()
try:
flags = 0
if machine_keyset:
flags |= cryptoapi.CRYPT_MACHINE_KEYSET
if not cryptoapi.CryptAcquireContext(
ctypes.byref(crypt_prov_handle),
container_name,
None,
cryptoapi.PROV_RSA_FULL,
flags):
flags |= cryptoapi.CRYPT_NEWKEYSET
if not cryptoapi.CryptAcquireContext(
ctypes.byref(crypt_prov_handle),
container_name,
None,
cryptoapi.PROV_RSA_FULL,
flags):
raise cryptoapi.CryptoAPIException()
# RSA 2048 bits
if not cryptoapi.CryptGenKey(crypt_prov_handle,
cryptoapi.AT_SIGNATURE,
0x08000000, key_handle):
raise cryptoapi.CryptoAPIException()
finally:
if key_handle:
cryptoapi.CryptDestroyKey(key_handle)
if crypt_prov_handle:
cryptoapi.CryptReleaseContext(crypt_prov_handle, 0)
def create_self_signed_cert(self, subject, validity_years=10,
machine_keyset=True, store_name="MY"):
subject_encoded = None
cert_context_p = None
store_handle = None
container_name = str(uuid.uuid4())
self._generate_key(container_name, machine_keyset)
try:
subject_encoded_len = wintypes.DWORD()
if not cryptoapi.CertStrToName(cryptoapi.X509_ASN_ENCODING,
subject,
cryptoapi.CERT_X500_NAME_STR, None,
None,
ctypes.byref(subject_encoded_len),
None):
raise cryptoapi.CryptoAPIException()
subject_encoded = ctypes.cast(malloc(subject_encoded_len),
ctypes.POINTER(wintypes.BYTE))
if not cryptoapi.CertStrToName(cryptoapi.X509_ASN_ENCODING,
subject,
cryptoapi.CERT_X500_NAME_STR, None,
subject_encoded,
ctypes.byref(subject_encoded_len),
None):
raise cryptoapi.CryptoAPIException()
subject_blob = cryptoapi.CRYPTOAPI_BLOB()
subject_blob.cbData = subject_encoded_len
subject_blob.pbData = subject_encoded
key_prov_info = cryptoapi.CRYPT_KEY_PROV_INFO()
key_prov_info.pwszContainerName = container_name
key_prov_info.pwszProvName = None
key_prov_info.dwProvType = cryptoapi.PROV_RSA_FULL
key_prov_info.cProvParam = None
key_prov_info.rgProvParam = None
key_prov_info.dwKeySpec = cryptoapi.AT_SIGNATURE
if machine_keyset:
key_prov_info.dwFlags = cryptoapi.CRYPT_MACHINE_KEYSET
else:
key_prov_info.dwFlags = 0
sign_alg = cryptoapi.CRYPT_ALGORITHM_IDENTIFIER()
sign_alg.pszObjId = cryptoapi.szOID_RSA_SHA1RSA
start_time = cryptoapi.SYSTEMTIME()
cryptoapi.GetSystemTime(ctypes.byref(start_time))
end_time = copy.copy(start_time)
end_time.wYear += validity_years
cert_context_p = cryptoapi.CertCreateSelfSignCertificate(
None, ctypes.byref(subject_blob), 0,
ctypes.byref(key_prov_info),
ctypes.byref(sign_alg), ctypes.byref(start_time),
ctypes.byref(end_time), None)
if not cert_context_p:
raise cryptoapi.CryptoAPIException()
if not cryptoapi.CertAddEnhancedKeyUsageIdentifier(
cert_context_p, cryptoapi.szOID_PKIX_KP_SERVER_AUTH):
raise cryptoapi.CryptoAPIException()
if machine_keyset:
flags = cryptoapi.CERT_SYSTEM_STORE_LOCAL_MACHINE
else:
flags = cryptoapi.CERT_SYSTEM_STORE_CURRENT_USER
store_handle = cryptoapi.CertOpenStore(
cryptoapi.CERT_STORE_PROV_SYSTEM, 0, 0, flags,
unicode(store_name))
if not store_handle:
raise cryptoapi.CryptoAPIException()
if not cryptoapi.CertAddCertificateContextToStore(
store_handle, cert_context_p,
cryptoapi.CERT_STORE_ADD_REPLACE_EXISTING, None):
raise cryptoapi.CryptoAPIException()
return self._get_cert_thumprint(cert_context_p)
finally:
if store_handle:
cryptoapi.CertCloseStore(store_handle, 0)
if cert_context_p:
cryptoapi.CertFreeCertificateContext(cert_context_p)
if subject_encoded:
free(subject_encoded)
def _get_cert_base64(self, cert_data):
header = "-----BEGIN CERTIFICATE-----\n"
footer = "-----END CERTIFICATE-----\n"
base64_cert_data = cert_data
if base64_cert_data.startswith(header):
base64_cert_data = base64_cert_data[len(header):]
if base64_cert_data.endswith(footer):
base64_cert_data = base64_cert_data[:len(base64_cert_data) -
len(footer)]
return base64_cert_data.replace("\n", "")
def import_cert(self, cert_data, machine_keyset=True,
store_name="TrustedPeople"):
base64_cert_data = self._get_cert_base64(cert_data)
cert_encoded = None
store_handle = None
cert_context_p = None
try:
cert_encoded_len = wintypes.DWORD()
if not cryptoapi.CryptStringToBinaryA(
base64_cert_data, len(base64_cert_data),
cryptoapi.CRYPT_STRING_BASE64,
None, ctypes.byref(cert_encoded_len),
None, None):
raise cryptoapi.CryptoAPIException()
cert_encoded = ctypes.cast(malloc(cert_encoded_len),
ctypes.POINTER(wintypes.BYTE))
if not cryptoapi.CryptStringToBinaryA(
base64_cert_data, len(base64_cert_data),
cryptoapi.CRYPT_STRING_BASE64,
cert_encoded, ctypes.byref(cert_encoded_len),
None, None):
raise cryptoapi.CryptoAPIException()
if machine_keyset:
flags = cryptoapi.CERT_SYSTEM_STORE_LOCAL_MACHINE
else:
flags = cryptoapi.CERT_SYSTEM_STORE_CURRENT_USER
store_handle = cryptoapi.CertOpenStore(
cryptoapi.CERT_STORE_PROV_SYSTEM, 0, 0, flags,
unicode(store_name))
if not store_handle:
raise cryptoapi.CryptoAPIException()
cert_context_p = ctypes.POINTER(cryptoapi.CERT_CONTEXT)()
if not cryptoapi.CertAddEncodedCertificateToStore(
store_handle,
cryptoapi.X509_ASN_ENCODING |
cryptoapi.PKCS_7_ASN_ENCODING,
cert_encoded, cert_encoded_len,
cryptoapi.CERT_STORE_ADD_REPLACE_EXISTING,
ctypes.byref(cert_context_p)):
raise cryptoapi.CryptoAPIException()
# Get the UPN (1.3.6.1.4.1.311.20.2.3 OID) from the
# certificate subject alt name
upn = None
upn_len = cryptoapi.CertGetNameString(
cert_context_p,
cryptoapi.CERT_NAME_UPN_TYPE, 0,
None, None, 0)
if upn_len > 1:
upn_ar = ctypes.create_unicode_buffer(upn_len)
if cryptoapi.CertGetNameString(
cert_context_p,
cryptoapi.CERT_NAME_UPN_TYPE,
0, None, upn_ar, upn_len) != upn_len:
raise cryptoapi.CryptoAPIException()
upn = upn_ar.value
thumbprint = self._get_cert_thumprint(cert_context_p)
return (thumbprint, upn)
finally:
if cert_context_p:
cryptoapi.CertFreeCertificateContext(cert_context_p)
if store_handle:
cryptoapi.CertCloseStore(store_handle, 0)
if cert_encoded:
free(cert_encoded)