diff --git a/cloudbaseinit/osutils/base.py b/cloudbaseinit/osutils/base.py index 4eb66ac2..a9866896 100644 --- a/cloudbaseinit/osutils/base.py +++ b/cloudbaseinit/osutils/base.py @@ -18,11 +18,11 @@ import base64 import os import subprocess -PROTOCOL_TCP = "TCP" -PROTOCOL_UDP = "UDP" - class BaseOSUtils(object): + PROTOCOL_TCP = "TCP" + PROTOCOL_UDP = "UDP" + def reboot(self): raise NotImplementedError() diff --git a/cloudbaseinit/osutils/windows.py b/cloudbaseinit/osutils/windows.py index 5f9573ee..9ebf8ab9 100644 --- a/cloudbaseinit/osutils/windows.py +++ b/cloudbaseinit/osutils/windows.py @@ -158,6 +158,19 @@ class WindowsUtils(base.BaseOSUtils): DRIVE_CDROM = 5 + SERVICE_STATUS_STOPPED = "Stopped" + SERVICE_STATUS_START_PENDING = "Start Pending" + SERVICE_STATUS_STOP_PENDING = "Stop Pending" + SERVICE_STATUS_RUNNING = "Running" + SERVICE_STATUS_CONTINUE_PENDING = "Continue Pending" + SERVICE_STATUS_PAUSE_PENDING = "Pause Pending" + SERVICE_STATUS_PAUSED = "Paused" + SERVICE_STATUS_UNKNOWN = "Unknown" + + SERVICE_START_MODE_AUTOMATIC = "Automatic" + SERVICE_START_MODE_MANUAL = "Manual" + SERVICE_START_MODE_DISABLED = "Disabled" + ComputerNamePhysicalDnsHostname = 5 _config_key = 'SOFTWARE\\Cloudbase Solutions\\Cloudbase-Init\\' @@ -417,12 +430,42 @@ class WindowsUtils(base.BaseOSUtils): else: raise ex - def _stop_service(self, service_name): - LOG.debug('Stopping service %s' % service_name) - + def _get_service(self, service_name): conn = wmi.WMI(moniker='//./root/cimv2') - service = conn.Win32_Service(Name=service_name)[0] + service_list = conn.Win32_Service(Name=service_name) + if len(service_list): + return service_list[0] + def check_service_exists(self, service_name): + return self._get_service(service_name) is not None + + def get_service_status(self, service_name): + service = self._get_service(service_name) + return service.State + + def get_service_start_mode(self, service_name): + service = self._get_service(service_name) + return service.StartMode + + def set_service_start_mode(self, service_name, start_mode): + #TODO(alexpilotti): Handle the "Delayed Start" case + service = self._get_service(service_name) + (ret_val,) = service.ChangeStartMode(start_mode) + if ret_val != 0: + raise Exception('Setting service %(service_name)s start mode ' + 'failed with return value: %(ret_val)d' % locals()) + + def start_service(self, service_name): + LOG.debug('Starting service %s' % service_name) + service = self._get_service(service_name) + (ret_val,) = service.StartService() + if ret_val != 0: + raise Exception('Starting service %(service_name)s failed with ' + 'return value: %(ret_val)d' % locals()) + + def stop_service(self, service_name): + LOG.debug('Stopping service %s' % service_name) + service = self._get_service(service_name) (ret_val,) = service.StopService() if ret_val != 0: raise Exception('Stopping service %(service_name)s failed with ' @@ -432,7 +475,7 @@ class WindowsUtils(base.BaseOSUtils): # Wait for the service to start. Polling the service "Started" property # is not enough time.sleep(3) - self._stop_service(self._service_name) + self.stop_service(self._service_name) def get_default_gateway(self): default_routes = [r for r in self._get_ipv4_routing_table() @@ -475,9 +518,10 @@ class WindowsUtils(base.BaseOSUtils): 'Error: %s' % err) forward_table = p_forward_table.contents - table = ctypes.cast(ctypes.addressof(forward_table.table), - ctypes.POINTER(Win32_MIB_IPFORWARDROW * - forward_table.dwNumEntries)).contents + table = ctypes.cast( + ctypes.addressof(forward_table.table), + ctypes.POINTER(Win32_MIB_IPFORWARDROW * + forward_table.dwNumEntries)).contents i = 0 while i < forward_table.dwNumEntries: @@ -578,9 +622,9 @@ class WindowsUtils(base.BaseOSUtils): self.DRIVE_CDROM] def _get_fw_protocol(self, protocol): - if protocol == base.PROTOCOL_TCP: + if protocol == self.PROTOCOL_TCP: fw_protocol = self._FW_IP_PROTOCOL_TCP - elif protocol == base.PROTOCOL_UDP: + elif protocol == self.PROTOCOL_UDP: fw_protocol = self._FW_IP_PROTOCOL_UDP else: raise NotImplementedError("Unsupported protocol") diff --git a/cloudbaseinit/plugins/factory.py b/cloudbaseinit/plugins/factory.py index e7898014..fb88d73e 100644 --- a/cloudbaseinit/plugins/factory.py +++ b/cloudbaseinit/plugins/factory.py @@ -29,6 +29,8 @@ opts = [ 'cloudbaseinit.plugins.windows.extendvolumes.ExtendVolumesPlugin', 'cloudbaseinit.plugins.windows.userdata.UserDataPlugin', 'cloudbaseinit.plugins.windows.setuserpassword.SetUserPasswordPlugin', + 'cloudbaseinit.plugins.windows.winrmlistener.' + 'ConfigWinRMListenerPlugin', ], help='List of enabled plugin classes, ' 'to executed in the provided order'), diff --git a/cloudbaseinit/plugins/windows/cryptoapi.py b/cloudbaseinit/plugins/windows/cryptoapi.py new file mode 100644 index 00000000..84094469 --- /dev/null +++ b/cloudbaseinit/plugins/windows/cryptoapi.py @@ -0,0 +1,229 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2013 Cloudbase Solutions Srl +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import ctypes + +from ctypes import windll +from ctypes import wintypes + + +class CryptoAPIException(Exception): + def __init__(self): + message = self._get_windows_error() + super(CryptoAPIException, self).__init__(message) + + def _get_windows_error(self): + err_code = GetLastError() + return "CryptoAPI error: 0x%0x" % err_code + + +class SYSTEMTIME(ctypes.Structure): + _fields_ = [ + ('wYear', wintypes.WORD), + ('wMonth', wintypes.WORD), + ('wDayOfWeek', wintypes.WORD), + ('wDay', wintypes.WORD), + ('wHour', wintypes.WORD), + ('wMinute', wintypes.WORD), + ('wSecond', wintypes.WORD), + ('wMilliseconds', wintypes.WORD), + ] + + +class CERT_CONTEXT(ctypes.Structure): + _fields_ = [ + ('dwCertEncodingType', wintypes.DWORD), + ('pbCertEncoded', ctypes.POINTER(wintypes.BYTE)), + ('cbCertEncoded', wintypes.DWORD), + ('pCertInfo', ctypes.c_void_p), + ('hCertStore', wintypes.HANDLE), + ] + + +class CRYPTOAPI_BLOB(ctypes.Structure): + _fields_ = [ + ('cbData', wintypes.DWORD), + ('pbData', ctypes.POINTER(wintypes.BYTE)), + ] + + +class CRYPT_ALGORITHM_IDENTIFIER(ctypes.Structure): + _fields_ = [ + ('pszObjId', wintypes.LPSTR), + ('Parameters', CRYPTOAPI_BLOB), + ] + + +class CRYPT_KEY_PROV_PARAM(ctypes.Structure): + _fields_ = [ + ('dwParam', wintypes.DWORD), + ('pbData', ctypes.POINTER(wintypes.BYTE)), + ('cbData', wintypes.DWORD), + ('dwFlags', wintypes.DWORD), + ] + + +class CRYPT_KEY_PROV_INFO(ctypes.Structure): + _fields_ = [ + ('pwszContainerName', wintypes.LPWSTR), + ('pwszProvName', wintypes.LPWSTR), + ('dwProvType', wintypes.DWORD), + ('dwFlags', wintypes.DWORD), + ('cProvParam', wintypes.DWORD), + ('cProvParam', ctypes.POINTER(CRYPT_KEY_PROV_PARAM)), + ('dwKeySpec', wintypes.DWORD), + ] + + +AT_SIGNATURE = 2 +CERT_NAME_UPN_TYPE = 8 +CERT_SHA1_HASH_PROP_ID = 3 +CERT_STORE_ADD_REPLACE_EXISTING = 3 +CERT_STORE_PROV_SYSTEM = wintypes.LPSTR(10) +CERT_SYSTEM_STORE_CURRENT_USER = 65536 +CERT_SYSTEM_STORE_LOCAL_MACHINE = 131072 +CERT_X500_NAME_STR = 3 +CRYPT_MACHINE_KEYSET = 32 +CRYPT_NEWKEYSET = 8 +CRYPT_STRING_BASE64 = 1 +PKCS_7_ASN_ENCODING = 65536 +PROV_RSA_FULL = 1 +X509_ASN_ENCODING = 1 +szOID_PKIX_KP_SERVER_AUTH = "1.3.6.1.5.5.7.3.1" +szOID_RSA_SHA1RSA = "1.2.840.113549.1.1.5" + +advapi32 = windll.advapi32 +crypt32 = windll.crypt32 +kernel32 = windll.kernel32 + +advapi32.CryptAcquireContextW.restype = wintypes.BOOL +advapi32.CryptAcquireContextW.argtypes = [wintypes.HANDLE, wintypes.LPCWSTR, + wintypes.LPCWSTR, wintypes.DWORD, + wintypes.DWORD] +CryptAcquireContext = advapi32.CryptAcquireContextW + +advapi32.CryptReleaseContext.restype = wintypes.BOOL +advapi32.CryptReleaseContext.argtypes = [wintypes.HANDLE, wintypes.DWORD] +CryptReleaseContext = advapi32.CryptReleaseContext + +advapi32.CryptGenKey.restype = wintypes.BOOL +advapi32.CryptGenKey.argtypes = [wintypes.HANDLE, + wintypes.DWORD, + wintypes.DWORD, + ctypes.POINTER(wintypes.HANDLE)] +CryptGenKey = advapi32.CryptGenKey + +advapi32.CryptDestroyKey.restype = wintypes.BOOL +advapi32.CryptDestroyKey.argtypes = [wintypes.HANDLE] +CryptDestroyKey = advapi32.CryptDestroyKey + +crypt32.CertStrToNameW.restype = wintypes.BOOL +crypt32.CertStrToNameW.argtypes = [wintypes.DWORD, wintypes.LPCWSTR, + wintypes.DWORD, ctypes.c_void_p, + ctypes.POINTER(wintypes.BYTE), + ctypes.POINTER(wintypes.DWORD), + ctypes.POINTER(wintypes.LPCWSTR)] +CertStrToName = crypt32.CertStrToNameW + +# TODO(alexpilotti): this is not a CryptoAPI funtion, putting it in a separate +# module would be more correct +kernel32.GetSystemTime.restype = None +kernel32.GetSystemTime.argtypes = [ctypes.POINTER(SYSTEMTIME)] +GetSystemTime = kernel32.GetSystemTime + +# TODO(alexpilotti): this is not a CryptoAPI funtion, putting it in a separate +# module would be more correct +kernel32.GetLastError.restype = wintypes.DWORD +kernel32.GetLastError.argtypes = [] +GetLastError = kernel32.GetLastError + +crypt32.CertCreateSelfSignCertificate.restype = ctypes.POINTER(CERT_CONTEXT) +crypt32.CertCreateSelfSignCertificate.argtypes = [ + wintypes.HANDLE, + ctypes.POINTER(CRYPTOAPI_BLOB), + wintypes.DWORD, + ctypes.POINTER(CRYPT_KEY_PROV_INFO), + ctypes.POINTER(CRYPT_ALGORITHM_IDENTIFIER), + ctypes.POINTER(SYSTEMTIME), + ctypes.POINTER(SYSTEMTIME), + # PCERT_EXTENSIONS + ctypes.c_void_p] +CertCreateSelfSignCertificate = crypt32.CertCreateSelfSignCertificate + +crypt32.CertAddEnhancedKeyUsageIdentifier.restype = wintypes.BOOL +crypt32.CertAddEnhancedKeyUsageIdentifier.argtypes = [ + ctypes.POINTER(CERT_CONTEXT), + wintypes.LPCSTR] +CertAddEnhancedKeyUsageIdentifier = crypt32.CertAddEnhancedKeyUsageIdentifier + +crypt32.CertOpenStore.restype = wintypes.HANDLE +crypt32.CertOpenStore.argtypes = [wintypes.LPCSTR, wintypes.DWORD, + wintypes.HANDLE, wintypes.DWORD, + ctypes.c_void_p] +CertOpenStore = crypt32.CertOpenStore + +crypt32.CertAddCertificateContextToStore.restype = wintypes.BOOL +crypt32.CertAddCertificateContextToStore.argtypes = [ + wintypes.HANDLE, + ctypes.POINTER(CERT_CONTEXT), + wintypes.DWORD, + ctypes.POINTER(CERT_CONTEXT)] +CertAddCertificateContextToStore = crypt32.CertAddCertificateContextToStore + +crypt32.CryptStringToBinaryA.restype = wintypes.BOOL +crypt32.CryptStringToBinaryA.argtypes = [wintypes.LPCSTR, + wintypes.DWORD, + wintypes.DWORD, + ctypes.POINTER(wintypes.BYTE), + ctypes.POINTER(wintypes.DWORD), + ctypes.POINTER(wintypes.DWORD), + ctypes.POINTER(wintypes.DWORD)] +CryptStringToBinaryA = crypt32.CryptStringToBinaryA + +crypt32.CertAddEncodedCertificateToStore.restype = wintypes.BOOL +crypt32.CertAddEncodedCertificateToStore.argtypes = [ + wintypes.HANDLE, + wintypes.DWORD, + ctypes.POINTER(wintypes.BYTE), + wintypes.DWORD, + wintypes.DWORD, + ctypes.POINTER(ctypes.POINTER(CERT_CONTEXT))] +CertAddEncodedCertificateToStore = crypt32.CertAddEncodedCertificateToStore + +crypt32.CertGetNameStringW.restype = wintypes.DWORD +crypt32.CertGetNameStringW.argtypes = [ctypes.POINTER(CERT_CONTEXT), + wintypes.DWORD, + wintypes.DWORD, + ctypes.c_void_p, + wintypes.LPWSTR, + wintypes.DWORD] +CertGetNameString = crypt32.CertGetNameStringW + +crypt32.CertFreeCertificateContext.restype = wintypes.BOOL +crypt32.CertFreeCertificateContext.argtypes = [ctypes.POINTER(CERT_CONTEXT)] +CertFreeCertificateContext = crypt32.CertFreeCertificateContext + +crypt32.CertCloseStore.restype = wintypes.BOOL +crypt32.CertCloseStore.argtypes = [wintypes.HANDLE, wintypes.DWORD] +CertCloseStore = crypt32.CertCloseStore + +crypt32.CertGetCertificateContextProperty.restype = wintypes.BOOL +crypt32.CertGetCertificateContextProperty.argtypes = [ + ctypes.POINTER(CERT_CONTEXT), + wintypes.DWORD, + ctypes.c_void_p, + ctypes.POINTER(wintypes.DWORD)] +CertGetCertificateContextProperty = crypt32.CertGetCertificateContextProperty diff --git a/cloudbaseinit/plugins/windows/winrmconfig.py b/cloudbaseinit/plugins/windows/winrmconfig.py new file mode 100644 index 00000000..e109e76e --- /dev/null +++ b/cloudbaseinit/plugins/windows/winrmconfig.py @@ -0,0 +1,155 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2012 Cloudbase Solutions Srl +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import pywintypes +import re + +from win32com import client +from xml.etree import ElementTree + + +CBT_HARDENING_LEVEL_NONE = "none" +CBT_HARDENING_LEVEL_RELAXED = "relaxed" +CBT_HARDENING_LEVEL_STRICT = "strict" + +LISTENER_PROTOCOL_HTTP = "HTTP" +LISTENER_PROTOCOL_HTTPS = "HTTPS" + + +class WinRMConfig(object): + _SERVICE_AUTH_URI = 'winrm/Config/Service/Auth' + _SERVICE_LISTENER_URI = 'winrm/Config/Listener?Address=*+Transport=%s' + + def _get_wsman_session(self): + wsman = client.Dispatch('WSMan.Automation') + return wsman.CreateSession() + + def _get_node_tag(self, tag): + return re.match("^{.*}(.*)$", tag).groups(1)[0] + + def _parse_listener_xml(self, data_xml): + listening_on = [] + data = {"ListeningOn": listening_on} + + ns = {'cfg': + 'http://schemas.microsoft.com/wbem/wsman/1/config/listener'} + tree = ElementTree.fromstring(data_xml) + for node in tree: + tag = self._get_node_tag(node.tag) + if tag == "ListeningOn": + listening_on.append(node.text) + elif tag == "Enabled": + if node.text == "true": + value = True + else: + value = False + data[tag] = value + elif tag == "Port": + data[tag] = int(node.text) + else: + data[tag] = node.text + + return data + + def get_listener(self, protocol=LISTENER_PROTOCOL_HTTPS): + session = self._get_wsman_session() + resourceUri = self._SERVICE_LISTENER_URI % protocol + try: + data_xml = session.Get(resourceUri) + except pywintypes.com_error, ex: + if len(ex.excepinfo) > 5 and ex.excepinfo[5] == -2144108544: + return None + else: + raise + + return self._parse_listener_xml(data_xml) + + def delete_listener(self, protocol=LISTENER_PROTOCOL_HTTPS): + session = self._get_wsman_session() + resourceUri = self._SERVICE_LISTENER_URI % protocol + session.Delete(resourceUri) + + def create_listener(self, protocol=LISTENER_PROTOCOL_HTTPS, enabled=True, + cert_thumbprint=None): + session = self._get_wsman_session() + resource_uri = self._SERVICE_LISTENER_URI % protocol + + if enabled: + enabled_str = "true" + else: + enabled_str = "false" + + session.Create( + resource_uri, + '' + '%(enabled_str)s' + '%(cert_thumbprint)s' + '' + 'wsman' + '' % {"enabled_str": enabled_str, + "cert_thumbprint": cert_thumbprint}) + + def get_auth_config(self): + data = {} + + session = self._get_wsman_session() + data_xml = session.Get(self._SERVICE_AUTH_URI) + tree = ElementTree.fromstring(data_xml) + for node in tree: + tag = self._get_node_tag(node.tag) + value_str = node.text.lower() + if value_str == "true": + value = True + elif value_str == "false": + value = False + else: + value = value_str + data[tag] = value + + return data + + def set_auth_config(self, basic=None, kerberos=None, negotiate=None, + certificate=None, credSSP=None, + cbt_hardening_level=None): + + tag_map = {'Basic': basic, + 'Kerberos': kerberos, + 'Negotiate': negotiate, + 'Certificate': certificate, + 'CredSSP': credSSP, + 'CbtHardeningLevel': cbt_hardening_level} + + session = self._get_wsman_session() + data_xml = session.Get(self._SERVICE_AUTH_URI) + + ns = {'cfg': + 'http://schemas.microsoft.com/wbem/wsman/1/config/service/auth'} + tree = ElementTree.fromstring(data_xml) + + for (tag, value) in tag_map.items(): + if value is not None: + if value: + new_value = "true" + else: + new_value = "false" + + node = tree.find('.//cfg:%s' % tag, namespaces=ns) + + if node.text.lower() != new_value: + node.text = new_value + data_xml = ElementTree.tostring(tree) + session.Put(self._SERVICE_AUTH_URI, data_xml) diff --git a/cloudbaseinit/plugins/windows/winrmlistener.py b/cloudbaseinit/plugins/windows/winrmlistener.py new file mode 100644 index 00000000..49eb6703 --- /dev/null +++ b/cloudbaseinit/plugins/windows/winrmlistener.py @@ -0,0 +1,91 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2013 Cloudbase Solutions Srl +# +# Licensed under the Apache License, Version 2.0 (the 'License'); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from cloudbaseinit.openstack.common import cfg +from cloudbaseinit.openstack.common import log as logging +from cloudbaseinit.osutils import factory as osutils_factory +from cloudbaseinit.plugins import base +from cloudbaseinit.plugins.windows import x509 +from cloudbaseinit.plugins.windows import winrmconfig + +LOG = logging.getLogger(__name__) + +opts = [ + cfg.BoolOpt('winrm_enable_basic_auth', default=True, + help='Enables basic authentication for the WinRM ' + 'HTTPS listener'), +] + +CONF = cfg.CONF +CONF.register_opts(opts) + +LOG = logging.getLogger(__name__) + + +class ConfigWinRMListenerPlugin(base.BasePlugin): + _cert_subject = "CN=Cloudbase-Init" + _winrm_service_name = "WinRM" + + def _check_winrm_service(self, osutils): + if not osutils.check_service_exists(self._winrm_service_name): + LOG.warn("Cannot configure the WinRM listener as the service " + "is not available") + return False + + start_mode = osutils.get_service_start_mode(self._winrm_service_name) + if start_mode in [osutils.SERVICE_START_MODE_MANUAL, + osutils.SERVICE_START_MODE_DISABLED]: + # TODO(alexpilotti) Set to "Delayed Start" + osutils.set_service_start_mode( + self._winrm_service_name, + osutils.SERVICE_START_MODE_AUTOMATIC) + + service_status = osutils.get_service_status(self._winrm_service_name) + if service_status == osutils.SERVICE_STATUS_STOPPED: + osutils.start_service(self._winrm_service_name) + + return True + + def execute(self, service): + osutils = osutils_factory.OSUtilsFactory().get_os_utils() + + if not self._check_winrm_service(osutils): + return (base.PLUGIN_EXECUTE_ON_NEXT_BOOT, False) + + winrm_config = winrmconfig.WinRMConfig() + winrm_config.set_auth_config(basic=CONF.winrm_enable_basic_auth) + + cert_manager = x509.CryptoAPICertManager() + cert_thumbprint = cert_manager.create_self_signed_cert( + self._cert_subject) + + protocol = winrmconfig.LISTENER_PROTOCOL_HTTPS + + if winrm_config.get_listener(protocol=protocol): + winrm_config.delete_listener(protocol=protocol) + + winrm_config.create_listener( + cert_thumbprint=cert_thumbprint, + protocol=protocol) + + listener_config = winrm_config.get_listener(protocol=protocol) + listener_port = listener_config.get("Port") + + rule_name = "WinRM %s" % protocol + osutils.firewall_create_rule(rule_name, listener_port, + osutils.PROTOCOL_TCP) + + return (base.PLUGIN_EXECUTION_DONE, False) diff --git a/cloudbaseinit/plugins/windows/x509.py b/cloudbaseinit/plugins/windows/x509.py new file mode 100644 index 00000000..8f378e9e --- /dev/null +++ b/cloudbaseinit/plugins/windows/x509.py @@ -0,0 +1,285 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2013 Cloudbase Solutions Srl +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import copy +import ctypes +import uuid + +from ctypes import wintypes + +from cloudbaseinit.plugins.windows import cryptoapi + +malloc = ctypes.cdll.msvcrt.malloc +malloc.restype = ctypes.c_void_p +malloc.argtypes = [ctypes.c_size_t] + +free = ctypes.cdll.msvcrt.free +free.restype = None +free.argtypes = [ctypes.c_void_p] + + +class CryptoAPICertManager(object): + def _get_cert_thumprint(self, cert_context_p): + thumbprint = None + + try: + thumprint_len = wintypes.DWORD() + + if not cryptoapi.CertGetCertificateContextProperty( + cert_context_p, + cryptoapi.CERT_SHA1_HASH_PROP_ID, + None, ctypes.byref(thumprint_len)): + raise cryptoapi.CryptoAPIException() + + thumbprint = malloc(thumprint_len) + + if not cryptoapi.CertGetCertificateContextProperty( + cert_context_p, + cryptoapi.CERT_SHA1_HASH_PROP_ID, + thumbprint, ctypes.byref(thumprint_len)): + raise cryptoapi.CryptoAPIException() + + thumbprint_ar = ctypes.cast( + thumbprint, + ctypes.POINTER(ctypes.c_ubyte * + thumprint_len.value)).contents + + thumbprint_str = "" + for b in thumbprint_ar: + thumbprint_str += "%02x" % b + return thumbprint_str + finally: + if thumbprint: + free(thumbprint) + + def _generate_key(self, container_name, machine_keyset): + crypt_prov_handle = wintypes.HANDLE() + key_handle = wintypes.HANDLE() + + try: + flags = 0 + if machine_keyset: + flags |= cryptoapi.CRYPT_MACHINE_KEYSET + + if not cryptoapi.CryptAcquireContext( + ctypes.byref(crypt_prov_handle), + container_name, + None, + cryptoapi.PROV_RSA_FULL, + flags): + flags |= cryptoapi.CRYPT_NEWKEYSET + if not cryptoapi.CryptAcquireContext( + ctypes.byref(crypt_prov_handle), + container_name, + None, + cryptoapi.PROV_RSA_FULL, + flags): + raise cryptoapi.CryptoAPIException() + + # RSA 2048 bits + if not cryptoapi.CryptGenKey(crypt_prov_handle, + cryptoapi.AT_SIGNATURE, + 0x08000000, key_handle): + raise cryptoapi.CryptoAPIException() + finally: + if key_handle: + cryptoapi.CryptDestroyKey(key_handle) + if crypt_prov_handle: + cryptoapi.CryptReleaseContext(crypt_prov_handle, 0) + + def create_self_signed_cert(self, subject, validity_years=10, + machine_keyset=True, store_name="MY"): + subject_encoded = None + cert_context_p = None + store_handle = None + + container_name = str(uuid.uuid4()) + self._generate_key(container_name, machine_keyset) + + try: + subject_encoded_len = wintypes.DWORD() + + if not cryptoapi.CertStrToName(cryptoapi.X509_ASN_ENCODING, + subject, + cryptoapi.CERT_X500_NAME_STR, None, + None, + ctypes.byref(subject_encoded_len), + None): + raise cryptoapi.CryptoAPIException() + + subject_encoded = ctypes.cast(malloc(subject_encoded_len), + ctypes.POINTER(wintypes.BYTE)) + + if not cryptoapi.CertStrToName(cryptoapi.X509_ASN_ENCODING, + subject, + cryptoapi.CERT_X500_NAME_STR, None, + subject_encoded, + ctypes.byref(subject_encoded_len), + None): + raise cryptoapi.CryptoAPIException() + + subject_blob = cryptoapi.CRYPTOAPI_BLOB() + subject_blob.cbData = subject_encoded_len + subject_blob.pbData = subject_encoded + + key_prov_info = cryptoapi.CRYPT_KEY_PROV_INFO() + key_prov_info.pwszContainerName = container_name + key_prov_info.pwszProvName = None + key_prov_info.dwProvType = cryptoapi.PROV_RSA_FULL + key_prov_info.cProvParam = None + key_prov_info.rgProvParam = None + key_prov_info.dwKeySpec = cryptoapi.AT_SIGNATURE + + if machine_keyset: + key_prov_info.dwFlags = cryptoapi.CRYPT_MACHINE_KEYSET + else: + key_prov_info.dwFlags = 0 + + sign_alg = cryptoapi.CRYPT_ALGORITHM_IDENTIFIER() + sign_alg.pszObjId = cryptoapi.szOID_RSA_SHA1RSA + + start_time = cryptoapi.SYSTEMTIME() + cryptoapi.GetSystemTime(ctypes.byref(start_time)) + + end_time = copy.copy(start_time) + end_time.wYear += validity_years + + cert_context_p = cryptoapi.CertCreateSelfSignCertificate( + None, ctypes.byref(subject_blob), 0, + ctypes.byref(key_prov_info), + ctypes.byref(sign_alg), ctypes.byref(start_time), + ctypes.byref(end_time), None) + if not cert_context_p: + raise cryptoapi.CryptoAPIException() + + if not cryptoapi.CertAddEnhancedKeyUsageIdentifier( + cert_context_p, cryptoapi.szOID_PKIX_KP_SERVER_AUTH): + raise cryptoapi.CryptoAPIException() + + if machine_keyset: + flags = cryptoapi.CERT_SYSTEM_STORE_LOCAL_MACHINE + else: + flags = cryptoapi.CERT_SYSTEM_STORE_CURRENT_USER + + store_handle = cryptoapi.CertOpenStore( + cryptoapi.CERT_STORE_PROV_SYSTEM, 0, 0, flags, + unicode(store_name)) + if not store_handle: + raise cryptoapi.CryptoAPIException() + + if not cryptoapi.CertAddCertificateContextToStore( + store_handle, cert_context_p, + cryptoapi.CERT_STORE_ADD_REPLACE_EXISTING, None): + raise cryptoapi.CryptoAPIException() + + return self._get_cert_thumprint(cert_context_p) + + finally: + if store_handle: + cryptoapi.CertCloseStore(store_handle, 0) + if cert_context_p: + cryptoapi.CertFreeCertificateContext(cert_context_p) + if subject_encoded: + free(subject_encoded) + + def _get_cert_base64(self, cert_data): + header = "-----BEGIN CERTIFICATE-----\n" + footer = "-----END CERTIFICATE-----\n" + + base64_cert_data = cert_data + if base64_cert_data.startswith(header): + base64_cert_data = base64_cert_data[len(header):] + if base64_cert_data.endswith(footer): + base64_cert_data = base64_cert_data[:len(base64_cert_data) - + len(footer)] + return base64_cert_data.replace("\n", "") + + def import_cert(self, cert_data, machine_keyset=True, + store_name="TrustedPeople"): + + base64_cert_data = self._get_cert_base64(cert_data) + + cert_encoded = None + store_handle = None + cert_context_p = None + + try: + cert_encoded_len = wintypes.DWORD() + + if not cryptoapi.CryptStringToBinaryA( + base64_cert_data, len(base64_cert_data), + cryptoapi.CRYPT_STRING_BASE64, + None, ctypes.byref(cert_encoded_len), + None, None): + raise cryptoapi.CryptoAPIException() + + cert_encoded = ctypes.cast(malloc(cert_encoded_len), + ctypes.POINTER(wintypes.BYTE)) + + if not cryptoapi.CryptStringToBinaryA( + base64_cert_data, len(base64_cert_data), + cryptoapi.CRYPT_STRING_BASE64, + cert_encoded, ctypes.byref(cert_encoded_len), + None, None): + raise cryptoapi.CryptoAPIException() + + if machine_keyset: + flags = cryptoapi.CERT_SYSTEM_STORE_LOCAL_MACHINE + else: + flags = cryptoapi.CERT_SYSTEM_STORE_CURRENT_USER + + store_handle = cryptoapi.CertOpenStore( + cryptoapi.CERT_STORE_PROV_SYSTEM, 0, 0, flags, + unicode(store_name)) + if not store_handle: + raise cryptoapi.CryptoAPIException() + + cert_context_p = ctypes.POINTER(cryptoapi.CERT_CONTEXT)() + + if not cryptoapi.CertAddEncodedCertificateToStore( + store_handle, + cryptoapi.X509_ASN_ENCODING | + cryptoapi.PKCS_7_ASN_ENCODING, + cert_encoded, cert_encoded_len, + cryptoapi.CERT_STORE_ADD_REPLACE_EXISTING, + ctypes.byref(cert_context_p)): + raise cryptoapi.CryptoAPIException() + + # Get the UPN (1.3.6.1.4.1.311.20.2.3 OID) from the + # certificate subject alt name + upn = None + upn_len = cryptoapi.CertGetNameString( + cert_context_p, + cryptoapi.CERT_NAME_UPN_TYPE, 0, + None, None, 0) + if upn_len > 1: + upn_ar = ctypes.create_unicode_buffer(upn_len) + if cryptoapi.CertGetNameString( + cert_context_p, + cryptoapi.CERT_NAME_UPN_TYPE, + 0, None, upn_ar, upn_len) != upn_len: + raise cryptoapi.CryptoAPIException() + upn = upn_ar.value + + thumbprint = self._get_cert_thumprint(cert_context_p) + return (thumbprint, upn) + finally: + if cert_context_p: + cryptoapi.CertFreeCertificateContext(cert_context_p) + if store_handle: + cryptoapi.CertCloseStore(store_handle, 0) + if cert_encoded: + free(cert_encoded)