diff --git a/pylintrc b/pylintrc index 99ed3a7..02f5209 100644 --- a/pylintrc +++ b/pylintrc @@ -9,3 +9,7 @@ no-docstring-rgx=((__.*__)|([tT]est.*)|setUp|tearDown)$ [Design] min-public-methods=0 max-args=6 + +[Master] +#We try to keep contrib files unmodified +ignore=satori/contrib \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 8f66f4d..779e2c6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,14 @@ +# slightly improved version of impacket that allows you to use impacket.examples.serviceinstall +# to create services that are not randomly named +-e git://github.com/nick-o/impacket.git@e4fcac42975fd2f20d9ae6e8643c0fd9fab33c7a#egg=impacket + ipaddress>=1.0.6 # in stdlib as of python3.3 iso8601>=0.1.5 Jinja2>=2.7.1 # bug resolve @2.7.1 paramiko>=1.12.0 # ecdsa added pbr>=0.5.21,<1.0 python-novaclient>=2.6.0.1 # breaks before -pythonwhois>=2.0.0 +pythonwhois>=2.4.3 six>=1.4.0 # urllib introduced tldextract>=1.2 argparse diff --git a/satori/__init__.py b/satori/__init__.py index c5a2d94..989ced3 100644 --- a/satori/__init__.py +++ b/satori/__init__.py @@ -14,6 +14,12 @@ __all__ = ('__version__') +try: + import eventlet + eventlet.monkey_patch() +except ImportError: + pass + import pbr.version from satori import shell diff --git a/satori/bash.py b/satori/bash.py index ca31a44..5744b7f 100644 --- a/satori/bash.py +++ b/satori/bash.py @@ -23,6 +23,7 @@ import shlex import subprocess from satori import errors +from satori import smb from satori import ssh from satori import utils @@ -33,16 +34,21 @@ class ShellMixin(object): """Handle platform detection and define execute command.""" - def execute(self, command, wd=None, with_exit_code=None): + def execute(self, command, **kwargs): """Execute a (shell) command on the target. :param command: Shell command to be executed :param with_exit_code: Include the exit_code in the return body. - :param wd: The child's current directory will be changed - to `wd` before it is executed. Note that this + :param cwd: The child's current directory will be changed + to `cwd` before it is executed. Note that this directory is not considered when searching the executable, so you can't specify the program's path relative to this argument + :returns: a dict with stdin, stdout, and + (optionally), the exit_code of the call + + See SSH.remote_execute(), SMB.remote_execute(), and + LocalShell.execute() for client-specific keyword arguments. """ pass @@ -88,6 +94,9 @@ class ShellMixin(object): Uses the platform_info property. """ + if hasattr(self, '_client'): + if isinstance(self._client, smb.SMBClient): + return True if not self.platform_info['dist']: raise errors.UndeterminedPlatform( 'Unable to determine whether the system is Windows based.') @@ -122,26 +131,28 @@ class LocalShell(ShellMixin): self._platform_info = utils.get_platform_info() return self._platform_info - def execute(self, command, wd=None, with_exit_code=None): + def execute(self, command, **kwargs): """Execute a command (containing no shell operators) locally. :param command: Shell command to be executed. :param with_exit_code: Include the exit_code in the return body. Default is False. - :param wd: The child's current directory will be changed - to `wd` before it is executed. Note that this + :param cwd: The child's current directory will be changed + to `cwd` before it is executed. Note that this directory is not considered when searching the executable, so you can't specify the program's path relative to this argument :returns: A dict with stdin, stdout, and (optionally) the exit code. """ + cwd = kwargs.get('cwd') + with_exit_code = kwargs.get('with_exit_code') spipe = subprocess.PIPE cmd = shlex.split(command) LOG.debug("Executing `%s` on local machine", command) result = subprocess.Popen( - cmd, stdout=spipe, stderr=spipe, cwd=wd) + cmd, stdout=spipe, stderr=spipe, cwd=cwd) out, err = result.communicate() resultdict = { 'stdout': out.strip(), @@ -159,7 +170,7 @@ class RemoteShell(ShellMixin): def __init__(self, address, password=None, username=None, private_key=None, key_filename=None, port=None, timeout=None, gateway=None, options=None, interactive=False, - **kwargs): + protocol='ssh', **kwargs): """An interface for executing shell commands on remote machines. :param str host: The ip address or host name of the server @@ -189,11 +200,20 @@ class RemoteShell(ShellMixin): LOG.warning("Satori RemoteClient received unrecognized " "keyword arguments: %s", kwargs.keys()) - self._client = ssh.connect( - address, password=password, username=username, - private_key=private_key, key_filename=key_filename, port=port, - timeout=timeout, gateway=gateway, options=options, - interactive=interactive) + if protocol == 'smb': + self._client = smb.connect(address, password=password, + username=username, + port=port, timeout=timeout, + gateway=gateway) + else: + self._client = ssh.connect(address, password=password, + username=username, + private_key=private_key, + key_filename=key_filename, + port=port, timeout=timeout, + gateway=gateway, + options=options, + interactive=interactive) self.host = self._client.host self.port = self._client.port diff --git a/satori/contrib/psexec.py b/satori/contrib/psexec.py new file mode 100644 index 0000000..fd335dc --- /dev/null +++ b/satori/contrib/psexec.py @@ -0,0 +1,558 @@ +#!/usr/bin/python +# Copyright (c) 2003-2012 CORE Security Technologies +# +# This software is provided under under a slightly modified version +# of the Apache Software License. See the accompanying LICENSE file +# for more information. +# +# $Id: psexec.py 712 2012-09-06 04:26:22Z bethus@gmail.com $ +# +# PSEXEC like functionality example using +#RemComSvc (https://github.com/kavika13/RemCom) +# +# Author: +# beto (bethus@gmail.com) +# +# Reference for: +# DCE/RPC and SMB. + +""". + +OK +""" + + +import cmd +import os +import re +import sys + +#from impacket.smbconnection import * +from impacket.dcerpc import dcerpc +from impacket.dcerpc import transport +from impacket.examples import remcomsvc +from impacket.examples import serviceinstall +from impacket import smbconnection +from impacket import structure as im_structure +from impacket import version +#from impacket.dcerpc import dcerpc_v4 +#from impacket.dcerpc import srvsvc +#from impacket.dcerpc import svcctl +#from impacket.smbconnection import smb +#from impacket.smbconnection import SMB_DIALECT +#from impacket.smbconnection import SMBConnection + +import argparse +import random +import string +import threading +import time + + +class RemComMessage(im_structure.Structure): + + """.""" + + structure = ( + ('Command', '4096s=""'), + ('WorkingDir', '260s=""'), + ('Priority', ' 0: + try: + s.waitNamedPipe(tid, pipe) + pipeReady = True + except Exception: + tries -= 1 + time.sleep(2) + pass + + if tries == 0: + print('[!] Pipe not ready, aborting') + raise + + fid = s.openFile(tid, pipe, accessMask, creationOption=0x40, + fileAttributes=0x80) + + return fid + + def doStuff(self, rpctransport): + """.""" + dce = dcerpc.DCERPC_v5(rpctransport) + try: + dce.connect() + except Exception as e: + print(e) + sys.exit(1) + + global dialect + dialect = rpctransport.get_smb_connection().getDialect() + + try: + unInstalled = False + s = rpctransport.get_smb_connection() + + # We don't wanna deal with timeouts from now on. + s.setTimeout(100000) + svcName = "RackspaceSystemDiscovery" + executableName = "RackspaceSystemDiscovery.exe" + if self.__exeFile is None: + svc = remcomsvc.RemComSvc() + installService = serviceinstall.ServiceInstall(s, svc, + svcName, + executableName) + else: + try: + f = open(self.__exeFile) + except Exception as e: + print(e) + sys.exit(1) + installService = serviceinstall.ServiceInstall(s, f, + svcName, + executableName) + + installService.install() + + if self.__exeFile is not None: + f.close() + + tid = s.connectTree('IPC$') + fid_main = self.openPipe(s, tid, '\RemCom_communicaton', 0x12019f) + + packet = RemComMessage() + pid = os.getpid() + + packet['Machine'] = ''.join([random.choice(string.letters) + for i in range(4)]) + if self.__path is not None: + packet['WorkingDir'] = self.__path + packet['Command'] = self.__command + packet['ProcessID'] = pid + + s.writeNamedPipe(tid, fid_main, str(packet)) + + # Here we'll store the command we type so we don't print it back ;) + # ( I know.. globals are nasty :P ) + global LastDataSent + LastDataSent = '' + + retCode = None + # Create the pipes threads + stdin_pipe = RemoteStdInPipe(rpctransport, + '\%s%s%d' % (RemComSTDIN, + packet['Machine'], + packet['ProcessID']), + smbconnection.smb.FILE_WRITE_DATA | + smbconnection.smb.FILE_APPEND_DATA, + installService.getShare()) + stdin_pipe.start() + stdout_pipe = RemoteStdOutPipe(rpctransport, + '\%s%s%d' % (RemComSTDOUT, + packet['Machine'], + packet['ProcessID']), + smbconnection.smb.FILE_READ_DATA) + stdout_pipe.start() + stderr_pipe = RemoteStdErrPipe(rpctransport, + '\%s%s%d' % (RemComSTDERR, + packet['Machine'], + packet['ProcessID']), + smbconnection.smb.FILE_READ_DATA) + stderr_pipe.start() + + # And we stay here till the end + ans = s.readNamedPipe(tid, fid_main, 8) + + if len(ans): + retCode = RemComResponse(ans) + print("[*] Process %s finished with ErrorCode: %d, " + "ReturnCode: %d" % (self.__command, retCode['ErrorCode'], + retCode['ReturnCode'])) + installService.uninstall() + unInstalled = True + sys.exit(retCode['ReturnCode']) + + except Exception: + if unInstalled is False: + installService.uninstall() + sys.stdout.flush() + if retCode: + sys.exit(retCode['ReturnCode']) + else: + sys.exit(1) + + +class Pipes(threading.Thread): + + """.""" + + def __init__(self, transport, pipe, permissions, share=None): + """.""" + threading.Thread.__init__(self) + self.server = 0 + self.transport = transport + self.credentials = transport.get_credentials() + self.tid = 0 + self.fid = 0 + self.share = share + self.port = transport.get_dport() + self.pipe = pipe + self.permissions = permissions + self.daemon = True + + def connectPipe(self): + """.""" + try: + lock.acquire() + global dialect + + remoteHost = self.transport.get_smb_connection().getRemoteHost() + #self.server = SMBConnection('*SMBSERVER', + #self.transport.get_smb_connection().getRemoteHost(), + #sess_port = self.port, preferredDialect = SMB_DIALECT) + self.server = smbconnection.SMBConnection('*SMBSERVER', remoteHost, + sess_port=self.port, + preferredDialect=dialect) # noqa + user, passwd, domain, lm, nt = self.credentials + self.server.login(user, passwd, domain, lm, nt) + lock.release() + self.tid = self.server.connectTree('IPC$') + + self.server.waitNamedPipe(self.tid, self.pipe) + self.fid = self.server.openFile(self.tid, self.pipe, + self.permissions, + creationOption=0x40, + fileAttributes=0x80) + self.server.setTimeout(1000000) + except Exception: + message = ("[!] Something wen't wrong connecting the pipes(%s), " + "try again") + print(message % self.__class__) + + +class RemoteStdOutPipe(Pipes): + + """.""" + + def __init__(self, transport, pipe, permisssions): + """.""" + Pipes.__init__(self, transport, pipe, permisssions) + + def run(self): + """.""" + self.connectPipe() + while True: + try: + ans = self.server.readFile(self.tid, self.fid, 0, 1024) + except Exception: + pass + else: + try: + global LastDataSent + if ans != LastDataSent: # noqa + sys.stdout.write(ans) + sys.stdout.flush() + else: + # Don't echo what I sent, and clear it up + LastDataSent = '' + # Just in case this got out of sync, i'm cleaning it + # up if there are more than 10 chars, + # it will give false positives tho.. we should find a + # better way to handle this. + if LastDataSent > 10: + LastDataSent = '' + except Exception: + pass + + +class RemoteStdErrPipe(Pipes): + + """.""" + + def __init__(self, transport, pipe, permisssions): + """.""" + Pipes.__init__(self, transport, pipe, permisssions) + + def run(self): + """.""" + self.connectPipe() + while True: + try: + ans = self.server.readFile(self.tid, self.fid, 0, 1024) + except Exception: + pass + else: + try: + sys.stderr.write(str(ans)) + sys.stderr.flush() + except Exception: + pass + + +class RemoteShell(cmd.Cmd): + + """.""" + + def __init__(self, server, port, credentials, tid, fid, share): + """.""" + cmd.Cmd.__init__(self, False) + self.prompt = '\x08' + self.server = server + self.transferClient = None + self.tid = tid + self.fid = fid + self.credentials = credentials + self.share = share + self.port = port + self.intro = '[!] Press help for extra shell commands' + + def connect_transferClient(self): + """.""" + #self.transferClient = SMBConnection('*SMBSERVER', + #self.server.getRemoteHost(), sess_port = self.port, + #preferredDialect = SMB_DIALECT) + self.transferClient = smbconnection.SMBConnection('*SMBSERVER', + self.server.getRemoteHost(), + sess_port=self.port, + preferredDialect=dialect) # noqa + user, passwd, domain, lm, nt = self.credentials + self.transferClient.login(user, passwd, domain, lm, nt) + + def do_help(self, line): + """.""" + print(""" + lcd {path} - changes the current local directory to {path} + exit - terminates the server process (and this session) + put {src_file, dst_path} - uploads a local file to the dst_path RELATIVE to + the connected share (%s) + get {file} - downloads pathname RELATIVE to the connected + share (%s) to the current local dir + ! {cmd} - executes a local shell cmd +""" % (self.share, self.share)) + self.send_data('\r\n', False) + + def do_shell(self, s): + """.""" + os.system(s) + self.send_data('\r\n') + + def do_get(self, src_path): + """.""" + try: + if self.transferClient is None: + self.connect_transferClient() + + import ntpath + filename = ntpath.basename(src_path) + fh = open(filename, 'wb') + print("[*] Downloading %s\%s" % (self.share, src_path)) + self.transferClient.getFile(self.share, src_path, fh.write) + fh.close() + except Exception as e: + print(e) + pass + + self.send_data('\r\n') + + def do_put(self, s): + """.""" + try: + if self.transferClient is None: + self.connect_transferClient() + params = s.split(' ') + if len(params) > 1: + src_path = params[0] + dst_path = params[1] + elif len(params) == 1: + src_path = params[0] + dst_path = '/' + + src_file = os.path.basename(src_path) + fh = open(src_path, 'rb') + f = dst_path + '/' + src_file + pathname = string.replace(f, '/', '\\') + print("[*] Uploading %s to %s\\%s" % (src_file, self.share, + dst_path)) + self.transferClient.putFile(self.share, pathname, fh.read) + fh.close() + except Exception as e: + print(e) + pass + + self.send_data('\r\n') + + def do_lcd(self, s): + """.""" + if s == '': + print(os.getcwd()) + else: + os.chdir(s) + self.send_data('\r\n') + + def emptyline(self): + """.""" + self.send_data('\r\n') + return + + def do_EOF(self, line): + """.""" + self.server.logoff() + + def default(self, line): + """.""" + self.send_data(line+'\r\n') + + def send_data(self, data, hideOutput=True): + """.""" + if hideOutput is True: + global LastDataSent + LastDataSent = data + else: + LastDataSent = '' + self.server.writeFile(self.tid, self.fid, data) + + +class RemoteStdInPipe(Pipes): + + """RemoteStdInPipe class. + + Used to connect to RemComSTDIN named pipe on remote system + """ + + def __init__(self, transport, pipe, permisssions, share=None): + """Constructor.""" + Pipes.__init__(self, transport, pipe, permisssions, share) + + def run(self): + """.""" + self.connectPipe() + self.shell = RemoteShell(self.server, self.port, self.credentials, + self.tid, self.fid, self.share) + self.shell.cmdloop() + + +# Process command-line arguments. +if __name__ == '__main__': + print(version.BANNER) + + parser = argparse.ArgumentParser() + + parser.add_argument('target', action='store', + help='[domain/][username[:password]@]
') + parser.add_argument('command', action='store', + help='command to execute at the target (w/o path)') + parser.add_argument('-path', action='store', + help='path of the command to execute') + parser.add_argument( + '-file', action='store', + help="alternative RemCom binary (be sure it doesn't require CRT)") + parser.add_argument( + '-port', action='store', + help='alternative port to use, this will copy settings from 445/SMB') + parser.add_argument('protocol', choices=PSEXEC.KNOWN_PROTOCOLS.keys(), + nargs='?', default='445/SMB', + help='transport protocol (default 445/SMB)') + + group = parser.add_argument_group('authentication') + + group.add_argument('-hashes', action="store", metavar="LMHASH:NTHASH", + help='NTLM hashes, format is LMHASH:NTHASH') + + if len(sys.argv) == 1: + parser.print_help() + sys.exit(1) + + options = parser.parse_args() + + domain, username, password, address = re.compile( + '(?:(?:([^/@:]*)/)?([^@:]*)(?::([^.]*))?@)?(.*)' + ).match(options.target).groups('') + + if domain is None: + domain = '' + + if options.port: + options.protocol = "%s/SMB" % options.port + + executer = PSEXEC(options.command, options.path, options.file, + options.protocol, username, password, domain, + options.hashes) + + if options.protocol not in PSEXEC.KNOWN_PROTOCOLS: + connection_string = 'ncacn_np:%s[\\pipe\\svcctl]' + PSEXEC.KNOWN_PROTOCOLS[options.protocol] = (connection_string, + options.port) + + executer.run(address) diff --git a/satori/discovery.py b/satori/discovery.py index f0864db..9c98d99 100644 --- a/satori/discovery.py +++ b/satori/discovery.py @@ -56,9 +56,9 @@ def run(target, config=None, interactive=False): found['hostname'] = hostname ip_address = six.text_type(dns.resolve_hostname(hostname)) # TODO(sam): Use ipaddress.ip_address.is_global - # " .is_private - # " .is_unspecified - # " .is_multicast + # .is_private + # .is_unspecified + # .is_multicast # To determine address "type" if not ipaddress.ip_address(ip_address).is_loopback: try: diff --git a/satori/smb.py b/satori/smb.py new file mode 100644 index 0000000..6839ada --- /dev/null +++ b/satori/smb.py @@ -0,0 +1,316 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Windows remote client module implemented using psexec.py.""" + +try: + import eventlet + eventlet.monkey_patch() + from eventlet.green import time +except ImportError: + import time + +import ast +import base64 +import logging +import os +import re +import shlex +import subprocess +import tempfile + +from satori import ssh +from satori import tunnel + +LOG = logging.getLogger(__name__) + + +def connect(*args, **kwargs): + """Connect to a remote device using psexec.py.""" + try: + return SMBClient.get_client(*args, **kwargs) + except Exception as exc: + LOG.error("ERROR: pse.py failed to connect: %s", str(exc)) + + +def _posh_encode(command): + """Encode a powershell command to base64. + + This is using utf-16 encoding and disregarding the first two bytes + :param command: command to encode + """ + return base64.b64encode(command.encode('utf-16')[2:]) + + +class SubprocessError(Exception): + + """Custom Exception. + + This will be raised when the subprocess running psexec.py has exited. + """ + + pass + + +class SMBClient(object): # pylint: disable=R0902 + + """Connects to devices over SMB/psexec to execute commands.""" + + _prompt_pattern = re.compile(r'^[a-zA-Z]:\\.*>$', re.MULTILINE) + + # pylint: disable=R0913 + def __init__(self, host, password=None, username="Administrator", + port=445, timeout=10, gateway=None, **kwargs): + """Create an instance of the PSE class. + + :param str host: The ip address or host name of the server + to connect to + :param str password: A password to use for authentication + :param str username: The username to authenticate as (defaults to + Administrator) + :param int port: tcp/ip port to use (defaults to 445) + :param float timeout: an optional timeout (in seconds) for the + TCP connection + :param gateway: instance of satori.ssh.SSH to be used to set up + an SSH tunnel (equivalent to ssh -L) + """ + self.password = password + self.host = host + self.port = port or 445 + self.username = username or 'Administrator' + self.timeout = timeout + self._connected = False + self._platform_info = None + self._process = None + self._orig_host = None + self._orig_port = None + self.ssh_tunnel = None + self._substituted_command = None + + # creating temp file to talk to _process with + self._file_write = tempfile.NamedTemporaryFile() + self._file_read = open(self._file_write.name, 'r') + + self._command = ("nice python %s/contrib/psexec.py -port %s %s:%s@%s " + "'c:\\Windows\\sysnative\\cmd'") + self._output = '' + self.gateway = gateway + + if gateway: + if not isinstance(self.gateway, ssh.SSH): + raise TypeError("'gateway' must be a satori.ssh.SSH instance. " + "( instances of this type are returned by" + "satori.ssh.connect() )") + + if kwargs: + LOG.debug("DEBUG: Following arguments passed into PSE constructor " + "not used: %s", kwargs.keys()) + + def __del__(self): + """Destructor of the PSE class.""" + try: + self.close() + except ValueError: + pass + + @classmethod + def get_client(cls, *args, **kwargs): + """Return a pse client object from this module.""" + return cls(*args, **kwargs) + + @property + def platform_info(self): + """Return Windows edition, version and architecture. + + requires Powershell version 3 + """ + if not self._platform_info: + command = ('Get-WmiObject Win32_OperatingSystem |' + ' select @{n="dist";e={$_.Caption.Trim()}},' + '@{n="version";e={$_.Version}},@{n="arch";' + 'e={$_.OSArchitecture}} | ' + ' ConvertTo-Json -Compress') + stdout = self.remote_execute(command, retry=3) + self._platform_info = ast.literal_eval(stdout) + + return self._platform_info + + def create_tunnel(self): + """Create an ssh tunnel via gateway. + + This will tunnel a local ephemeral port to the host's port. + This will preserve the original host and port + """ + self.ssh_tunnel = tunnel.connect(self.host, self.port, self.gateway) + self._orig_host = self.host + self._orig_port = self.port + self.host, self.port = self.ssh_tunnel.address + self.ssh_tunnel.serve_forever(async=True) + + def shutdown_tunnel(self): + """Terminate the ssh tunnel. Restores original host and port.""" + self.ssh_tunnel.shutdown() + self.host = self._orig_host + self.port = self._orig_port + + def test_connection(self): + """Connect to a Windows server and disconnect again. + + Make sure the returncode is 0, otherwise return False + """ + self.connect() + self.close() + self._get_output() + if self._output.find('ErrorCode: 0, ReturnCode: 0') > -1: + return True + else: + return False + + def connect(self): + """Attempt a connection using psexec.py. + + This will create a subprocess.Popen() instance and communicate with it + via _file_read/_file_write and _process.stdin + """ + try: + if self._connected and self._process: + if self._process.poll() is None: + return + else: + self._process.wait() + if self.gateway: + self.shutdown_tunnel() + if self.gateway: + self.create_tunnel() + self._substituted_command = self._command % ( + os.path.dirname(__file__), + self.port, + self.username, + self.password, + self.host) + self._process = subprocess.Popen( + shlex.split(self._substituted_command), + stdout=self._file_write, + stderr=subprocess.STDOUT, + stdin=subprocess.PIPE, + close_fds=True, + bufsize=0) + output = '' + while not self._prompt_pattern.findall(output): + output += self._get_output() + self._connected = True + except Exception: + self.close() + raise + + def close(self): + """Close the psexec connection by sending 'exit' to the subprocess. + + This will cleanly exit psexec (i.e. stop and uninstall the service and + delete the files) + + This method will be called when an instance of this class is about to + being destroyed. It will try to close the connection (which will clean + up on the remote server) and catch the exception that is raised when + the connection has already been closed. + """ + try: + self._process.communicate('exit') + except Exception as exc: + LOG.warning("ERROR: Failed to close %s: %s", self, str(exc)) + del exc + try: + if self.gateway: + self.shutdown_tunnel() + self.gateway.close() + except Exception as exc: + LOG.warning("ERROR: Failed to close gateway %s: %s", self.gateway, + str(exc)) + del exc + finally: + if self._process: + LOG.warning("Killing process: %s", self._process) + subprocess.call(['pkill', '-STOP', '-P', + str(self._process.pid)]) + + def remote_execute(self, command, powershell=True, retry=0, **kwargs): + """Execute a command on a remote host. + + :param command: Command to be executed + :param powershell: If True, command will be interpreted as Powershell + command and therefore converted to base64 and + prepended with 'powershell -EncodedCommand + :param int retry: Number of retries when SubprocessError is thrown + by _get_output before giving up + """ + self.connect() + if powershell: + command = ('powershell -EncodedCommand %s' % + _posh_encode(command)) + self._process.stdin.write('%s\n' % command) + try: + output = self._get_output() + output = "\n".join(output.splitlines()[:-1]).strip() + return output + except SubprocessError: + if not retry: + raise + else: + return self.remote_execute(command, powershell=powershell, + retry=retry - 1) + + def _get_output(self, prompt_expected=True, wait=200): + """Retrieve output from _process. + + This method will wait until output is started to be received and then + wait until no further output is received within a defined period + :param prompt_expected: only return when regular expression defined + in _prompt_pattern is matched + :param wait: Time in milliseconds to wait in each of the + two loops that wait for (more) output. + """ + tmp_out = '' + while tmp_out == '': + self._file_read.seek(0, 1) + tmp_out += self._file_read.read() + # leave loop if underlying process has a return code + # obviously meaning that it has terminated + if self._process.poll() is not None: + import json + error = {"error": tmp_out} + raise SubprocessError("subprocess with pid: %s has terminated " + "unexpectedly with return code: %s\n%s" + % (self._process.pid, + self._process.poll(), + json.dumps(error))) + time.sleep(wait/1000) + stdout = tmp_out + while (not tmp_out == '' or + (not self._prompt_pattern.findall(stdout) and + prompt_expected)): + self._file_read.seek(0, 1) + tmp_out = self._file_read.read() + stdout += tmp_out + # leave loop if underlying process has a return code + # obviously meaning that it has terminated + if self._process.poll() is not None: + import json + error = {"error": stdout} + raise SubprocessError("subprocess with pid: %s has terminated " + "unexpectedly with return code: %s\n%s" + % (self._process.pid, + self._process.poll(), + json.dumps(error))) + time.sleep(wait/1000) + self._output += stdout + stdout = stdout.replace('\r', '').replace('\x08', '') + return stdout diff --git a/satori/ssh.py b/satori/ssh.py index db39582..5a468e4 100644 --- a/satori/ssh.py +++ b/satori/ssh.py @@ -129,10 +129,10 @@ class SSH(paramiko.SSHClient): # pylint: disable=R0902 """ self.password = password self.host = host - self.username = username + self.username = username or 'root' self.private_key = private_key self.key_filename = key_filename - self.port = port + self.port = port or 22 self.timeout = timeout self._platform_info = None self.options = options or {} @@ -160,7 +160,6 @@ class SSH(paramiko.SSHClient): # pylint: disable=R0902 Requires >= Python 2.4 on remote system. """ if not self._platform_info: - platform_command = "import platform,sys\n" platform_command += utils.get_source_definition( utils.get_platform_info) @@ -169,10 +168,18 @@ class SSH(paramiko.SSHClient): # pylint: disable=R0902 command = 'echo -e """%s""" | python' % platform_command output = self.remote_execute(command) stdout = re.split('\n|\r\n', output['stdout'])[-1].strip() - plat = ast.literal_eval(stdout) + if stdout: + try: + plat = ast.literal_eval(stdout) + except SyntaxError as exc: + plat = {'dist': 'unknown'} + LOG.warning("Error parsing response from host '%s': %s", + self.host, output, exc_info=exc) + else: + plat = {'dist': 'unknown'} + LOG.warning("Blank response from host '%s': %s", + self.host, output) self._platform_info = plat - - LOG.debug("Remote platform info: %s", self._platform_info) return self._platform_info def connect_with_host_keys(self): @@ -362,7 +369,7 @@ class SSH(paramiko.SSHClient): # pylint: disable=R0902 return False def remote_execute(self, command, with_exit_code=False, - get_pty=False, wd=None): + get_pty=False, cwd=None, **kwargs): """Execute an ssh command on a remote host. Tries cert auth first and falls back @@ -370,8 +377,8 @@ class SSH(paramiko.SSHClient): # pylint: disable=R0902 :param command: Shell command to be executed by this function. :param with_exit_code: Include the exit_code in the return body. - :param wd: The child's current directory will be changed - to `wd` before it is executed. Note that this + :param cwd: The child's current directory will be changed + to `cwd` before it is executed. Note that this directory is not considered when searching the executable, so you can't specify the program's path relative to this argument @@ -380,8 +387,8 @@ class SSH(paramiko.SSHClient): # pylint: disable=R0902 :returns: a dict with stdin, stdout, and (optionally) the exit code of the call. """ - if wd: - prefix = "cd %s && " % wd + if cwd: + prefix = "cd %s && " % cwd command = prefix + command LOG.debug("Executing '%s' on ssh://%s@%s:%s.", @@ -408,10 +415,10 @@ class SSH(paramiko.SSHClient): # pylint: disable=R0902 'stderr': stderr.read() } - LOG.debug("STDOUT from ssh://%s@%s:%d: %s", + LOG.debug("STDOUT from ssh://%s@%s:%d: %.5000s ...", self.username, self.host, self.port, results['stdout']) - LOG.debug("STDERR from ssh://%s@%s:%d: %s", + LOG.debug("STDERR from ssh://%s@%s:%d: %.5000s ...", self.username, self.host, self.port, results['stderr']) exit_code = chan.recv_exit_status() diff --git a/satori/sysinfo/ohai_solo.py b/satori/sysinfo/ohai_solo.py index 29cd600..5dc62a4 100644 --- a/satori/sysinfo/ohai_solo.py +++ b/satori/sysinfo/ohai_solo.py @@ -24,10 +24,6 @@ from satori import errors from satori import utils LOG = logging.getLogger(__name__) -if six.PY3: - def unicode(text, errors=None): # noqa - """A hacky Python 3 version of unicode() function.""" - return str(text) def get_systeminfo(ipaddress, config, interactive=False): @@ -38,7 +34,7 @@ def get_systeminfo(ipaddress, config, interactive=False): :keyword interactive: whether to prompt the user for information. """ if (ipaddress in utils.get_local_ips() or - ipaddress_module.ip_address(unicode(ipaddress)).is_loopback): + ipaddress_module.ip_address(six.text_type(ipaddress)).is_loopback): client = bash.LocalShell() client.host = "localhost" @@ -66,46 +62,66 @@ def system_info(client): SystemInfoNotJson if `ohai` does not return valid JSON. SystemInfoMissingJson if `ohai` does not return any JSON. """ - output = client.execute("sudo -i ohai-solo") - not_found_msgs = ["command not found", "Could not find ohai"] - if any(m in k for m in not_found_msgs - for k in list(output.values()) if isinstance(k, six.string_types)): - LOG.warning("SystemInfoCommandMissing on host: [%s]", client.host) - raise errors.SystemInfoCommandMissing("ohai-solo missing on %s", - client.host) - unicode_output = unicode(output['stdout'], errors='replace') - try: - results = json.loads(unicode_output) - except ValueError as exc: + if client.is_windows(): + raise errors.UnsupportedPlatform( + "ohai-solo is a linux-only sytem info provider. " + "Target platform was %s", client.platform_info['dist']) + else: + output = client.execute("sudo -i ohai-solo") + not_found_msgs = ["command not found", "Could not find ohai"] + if any(m in k for m in not_found_msgs + for k in list(output.values()) if isinstance(k, + six.string_types)): + LOG.warning("SystemInfoCommandMissing on host: [%s]", client.host) + raise errors.SystemInfoCommandMissing("ohai-solo missing on %s" % + client.host) + # use string formatting to handle unicode + unicode_output = "%s" % output['stdout'] try: - clean_output = get_json(unicode_output) - results = json.loads(clean_output) + results = json.loads(unicode_output) except ValueError as exc: - raise errors.SystemInfoNotJson(exc) - return results + try: + clean_output = get_json(unicode_output) + results = json.loads(clean_output) + except ValueError as exc: + raise errors.SystemInfoNotJson(exc) + return results def install_remote(client): """Install ohai-solo on remote system.""" LOG.info("Installing (or updating) ohai-solo on device %s at %s:%d", client.host, client.host, client.port) - # Download to host - command = "sudo wget -N http://ohai.rax.io/install.sh" - client.execute(command, wd='/tmp') - # Run install - command = "sudo bash install.sh" - output = client.execute(command, wd='/tmp', with_exit_code=True) - - # Be a good citizen and clean up your tmp data - command = "sudo rm install.sh" - client.execute(command, wd='/tmp') - - # Process install command output - if output['exit_code'] != 0: - raise errors.SystemInfoCommandInstallFailed(output['stderr'][:256]) + # Check if it a windows box, but fail safely to Linux + is_windows = False + try: + is_windows = client.is_windows() + except Exception: + pass + if is_windows: + raise errors.UnsupportedPlatform( + "ohai-solo is a linux-only sytem info provider. " + "Target platform was %s", client.platform_info['dist']) else: - return output + # Download to host + command = "sudo wget -N http://ohai.rax.io/install.sh" + client.execute(command, cwd='/tmp') + + # Run install + command = "sudo bash install.sh" + output = client.execute(command, cwd='/tmp', with_exit_code=True) + + # Be a good citizen and clean up your tmp data + command = "sudo rm install.sh" + client.execute(command, cwd='/tmp') + + # Process install command output + if output['exit_code'] != 0: + raise errors.SystemInfoCommandInstallFailed( + output['stderr'][:256]) + else: + return output def remove_remote(client): @@ -117,24 +133,29 @@ def remove_remote(client): - redhat [5.x, 6.x] - centos [5.x, 6.x] """ - platform_info = client.platform_info - if client.is_debian(): - remove = "sudo dpkg --purge ohai-solo" - elif client.is_fedora(): - remove = "sudo yum -y erase ohai-solo" + if client.is_windows(): + raise errors.UnsupportedPlatform( + "ohai-solo is a linux-only sytem info provider. " + "Target platform was %s", client.platform_info['dist']) else: - raise errors.UnsupportedPlatform("Unknown distro: %s" % - platform_info['dist']) - command = "%s" % remove - output = client.execute(command, wd='/tmp') - return output + platform_info = client.platform_info + if client.is_debian(): + remove = "sudo dpkg --purge ohai-solo" + elif client.is_fedora(): + remove = "sudo yum -y erase ohai-solo" + else: + raise errors.UnsupportedPlatform("Unknown distro: %s" % + platform_info['dist']) + command = "%s" % remove + output = client.execute(command, cwd='/tmp') + return output def get_json(data): """Find the JSON string in data and return a string. :param data: :string: - :returns: string -- JSON string striped of non-JSON data + :returns: string -- JSON string stripped of non-JSON data :raises: SystemInfoMissingJson SystemInfoMissingJson if `ohai` does not return any JSON. diff --git a/satori/sysinfo/posh_ohai.py b/satori/sysinfo/posh_ohai.py new file mode 100644 index 0000000..605414e --- /dev/null +++ b/satori/sysinfo/posh_ohai.py @@ -0,0 +1,149 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +# +# pylint: disable=W0622 +"""PoSh-Ohai Data Plane Discovery Module.""" + +import json +import logging + +import ipaddress as ipaddress_module +import six + +from satori import bash +from satori import errors +from satori import utils + +LOG = logging.getLogger(__name__) + + +def get_systeminfo(ipaddress, config, interactive=False): + """Run data plane discovery using this module against a host. + + :param ipaddress: address to the host to discover. + :param config: arguments and configuration suppplied to satori. + :keyword interactive: whether to prompt the user for information. + """ + if (ipaddress in utils.get_local_ips() or + ipaddress_module.ip_address(six.text_type(ipaddress)).is_loopback): + + client = bash.LocalShell() + client.host = "localhost" + client.port = 0 + + else: + client = bash.RemoteShell(ipaddress, username=config['host_username'], + private_key=config['host_key'], + interactive=interactive) + + install_remote(client) + return system_info(client) + + +def system_info(client): + """Run Posh-Ohai on a remote system and gather the output. + + :param client: :class:`smb.SMB` instance + :returns: dict -- system information from PoSh-Ohai + :raises: SystemInfoCommandMissing, SystemInfoCommandOld, SystemInfoNotJson + SystemInfoMissingJson + + SystemInfoCommandMissing if `posh-ohai` is not installed. + SystemInfoCommandOld if `posh-ohai` is not the latest. + SystemInfoNotJson if `posh-ohai` does not return valid JSON. + SystemInfoMissingJson if `posh-ohai` does not return any JSON. + """ + if client.is_windows(): + powershell_command = 'Get-ComputerConfiguration' + output = client.execute(powershell_command) + unicode_output = "%s" % output + try: + results = json.loads(unicode_output) + except ValueError: + try: + clean_output = get_json(unicode_output) + results = json.loads(clean_output) + except ValueError as err: + raise errors.SystemInfoNotJson(err) + return results + else: + raise errors.PlatformNotSupported( + "PoSh-Ohai is a Windows-only sytem info provider. " + "Target platform was %s", client.platform_info['dist']) + + +def install_remote(client): + """Install PoSh-Ohai on remote system.""" + LOG.info("Installing (or updating) PoSh-Ohai on device %s at %s:%d", + client.host, client.host, client.port) + + # Check is it is a windows box, but fail safely to Linux + is_windows = False + try: + is_windows = client.is_windows() + except Exception: + pass + if is_windows: + powershell_command = ('[scriptblock]::Create((New-Object -TypeName ' + 'System.Net.WebClient).DownloadString(' + '"http://ohai.rax.io/deploy.ps1"))' + '.Invoke()') + # check output to ensure that installation was successful + # if not, raise SystemInfoCommandInstallFailed + output = client.execute(powershell_command) + return output + else: + raise errors.PlatformNotSupported( + "PoSh-Ohai is a Windows-only sytem info provider. " + "Target platform was %s", client.platform_info['dist']) + + +def remove_remote(client): + """Remove PoSh-Ohai from specifc remote system. + + Currently supports: + - ubuntu [10.x, 12.x] + - debian [6.x, 7.x] + - redhat [5.x, 6.x] + - centos [5.x, 6.x] + """ + if client.is_windows(): + powershell_command = ('Remove-Item -Path (Join-Path -Path ' + '$($env:PSModulePath.Split(";") ' + '| Where-Object { $_.StartsWith(' + '$env:SystemRoot)}) -ChildPath ' + '"PoSh-Ohai") -Recurse -Force -ErrorAction ' + 'SilentlyContinue') + output = client.execute(powershell_command) + return output + else: + raise errors.PlatformNotSupported( + "PoSh-Ohai is a Windows-only sytem info provider. " + "Target platform was %s", client.platform_info['dist']) + + +def get_json(data): + """Find the JSON string in data and return a string. + + :param data: :string: + :returns: string -- JSON string stripped of non-JSON data + :raises: SystemInfoMissingJson + + SystemInfoMissingJson if no JSON is returned. + """ + try: + first = data.index('{') + last = data.rindex('}') + return data[first:last + 1] + except ValueError as exc: + context = {"ValueError": "%s" % exc} + raise errors.SystemInfoMissingJson(context) diff --git a/satori/tests/test_sysinfo_ohai_solo.py b/satori/tests/test_sysinfo_ohai_solo.py index 27a6fe6..fa7884e 100644 --- a/satori/tests/test_sysinfo_ohai_solo.py +++ b/satori/tests/test_sysinfo_ohai_solo.py @@ -48,119 +48,124 @@ class TestOhaiSolo(utils.TestCase): class TestOhaiInstall(utils.TestCase): - def test_install_remote_fedora(self): - mock_ssh = mock.MagicMock() - response = {'exit_code': 0, 'foo': 'bar'} - mock_ssh.execute.return_value = response - result = ohai_solo.install_remote(mock_ssh) - self.assertEqual(result, response) - self.assertEqual(mock_ssh.execute.call_count, 3) - mock_ssh.execute.assert_has_calls([ - mock.call('sudo wget -N http://ohai.rax.io/install.sh', wd='/tmp'), - mock.call('sudo bash install.sh', wd='/tmp', with_exit_code=True), - mock.call('sudo rm install.sh', wd='/tmp')]) + def setUp(self): + super(TestOhaiInstall, self).setUp() + self.mock_remotesshclient = mock.MagicMock() + self.mock_remotesshclient.is_windows.return_value = False - def test_install_remote_failed(self): - mock_ssh = mock.MagicMock() + def test_install_remote_fedora(self): + response = {'exit_code': 0, 'foo': 'bar'} + self.mock_remotesshclient.execute.return_value = response + result = ohai_solo.install_remote(self.mock_remotesshclient) + self.assertEqual(result, response) + self.assertEqual(self.mock_remotesshclient.execute.call_count, 3) + self.mock_remotesshclient.execute.assert_has_calls([ + mock.call('sudo wget -N http://ohai.rax.io/install.sh', cwd='/tmp'), + mock.call('sudo bash install.sh', cwd='/tmp', with_exit_code=True), + mock.call('sudo rm install.sh', cwd='/tmp')]) + + def test_install_linux_remote_failed(self): response = {'exit_code': 1, 'stdout': "", "stderr": "FAIL"} - mock_ssh.execute.return_value = response + self.mock_remotesshclient.execute.return_value = response self.assertRaises(errors.SystemInfoCommandInstallFailed, - ohai_solo.install_remote, mock_ssh) + ohai_solo.install_remote, self.mock_remotesshclient) class TestOhaiRemove(utils.TestCase): + def setUp(self): + super(TestOhaiRemove, self).setUp() + self.mock_remotesshclient = mock.MagicMock() + self.mock_remotesshclient.is_windows.return_value = False + def test_remove_remote_fedora(self): - mock_ssh = mock.MagicMock() - mock_ssh.is_debian.return_value = False - mock_ssh.is_fedora.return_value = True + self.mock_remotesshclient.is_debian.return_value = False + self.mock_remotesshclient.is_fedora.return_value = True response = {'exit_code': 0, 'foo': 'bar'} - mock_ssh.execute.return_value = response - result = ohai_solo.remove_remote(mock_ssh) + self.mock_remotesshclient.execute.return_value = response + result = ohai_solo.remove_remote(self.mock_remotesshclient) self.assertEqual(result, response) - mock_ssh.execute.assert_called_once_with( - 'sudo yum -y erase ohai-solo', wd='/tmp') + self.mock_remotesshclient.execute.assert_called_once_with( + 'sudo yum -y erase ohai-solo', cwd='/tmp') def test_remove_remote_debian(self): - mock_ssh = mock.MagicMock() - mock_ssh.is_debian.return_value = True - mock_ssh.is_fedora.return_value = False + self.mock_remotesshclient.is_debian.return_value = True + self.mock_remotesshclient.is_fedora.return_value = False response = {'exit_code': 0, 'foo': 'bar'} - mock_ssh.execute.return_value = response - result = ohai_solo.remove_remote(mock_ssh) + self.mock_remotesshclient.execute.return_value = response + result = ohai_solo.remove_remote(self.mock_remotesshclient) self.assertEqual(result, response) - mock_ssh.execute.assert_called_once_with( - 'sudo dpkg --purge ohai-solo', wd='/tmp') + self.mock_remotesshclient.execute.assert_called_once_with( + 'sudo dpkg --purge ohai-solo', cwd='/tmp') def test_remove_remote_unsupported(self): - mock_ssh = mock.MagicMock() - mock_ssh.is_debian.return_value = False - mock_ssh.is_fedora.return_value = False + self.mock_remotesshclient.is_debian.return_value = False + self.mock_remotesshclient.is_fedora.return_value = False self.assertRaises(errors.UnsupportedPlatform, - ohai_solo.remove_remote, mock_ssh) + ohai_solo.remove_remote, self.mock_remotesshclient) class TestSystemInfo(utils.TestCase): + def setUp(self): + super(TestSystemInfo, self).setUp() + self.mock_remotesshclient = mock.MagicMock() + self.mock_remotesshclient.is_windows.return_value = False + def test_system_info(self): - mock_ssh = mock.MagicMock() - mock_ssh.execute.return_value = { + self.mock_remotesshclient.execute.return_value = { 'exit_code': 0, 'stdout': "{}", 'stderr': "" } - ohai_solo.system_info(mock_ssh) - mock_ssh.execute.assert_called_with("sudo -i ohai-solo") + ohai_solo.system_info(self.mock_remotesshclient) + self.mock_remotesshclient.execute.assert_called_with( + "sudo -i ohai-solo") def test_system_info_with_motd(self): - mock_ssh = mock.MagicMock() - mock_ssh.execute.return_value = { + self.mock_remotesshclient.execute.return_value = { 'exit_code': 0, 'stdout': "Hello world\n {}", 'stderr': "" } - ohai_solo.system_info(mock_ssh) - mock_ssh.execute.assert_called_with("sudo -i ohai-solo") + ohai_solo.system_info(self.mock_remotesshclient) + self.mock_remotesshclient.execute.assert_called_with("sudo -i ohai-solo") def test_system_info_bad_json(self): - mock_ssh = mock.MagicMock() - mock_ssh.execute.return_value = { + self.mock_remotesshclient.execute.return_value = { 'exit_code': 0, 'stdout': "{Not JSON!}", 'stderr': "" } self.assertRaises(errors.SystemInfoNotJson, ohai_solo.system_info, - mock_ssh) + self.mock_remotesshclient) def test_system_info_missing_json(self): - mock_ssh = mock.MagicMock() - mock_ssh.execute.return_value = { + self.mock_remotesshclient.execute.return_value = { 'exit_code': 0, 'stdout': "No JSON!", 'stderr': "" } self.assertRaises(errors.SystemInfoMissingJson, ohai_solo.system_info, - mock_ssh) + self.mock_remotesshclient) def test_system_info_command_not_found(self): - mock_ssh = mock.MagicMock() - mock_ssh.execute.return_value = { + self.mock_remotesshclient.execute.return_value = { 'exit_code': 1, 'stdout': "", 'stderr': "ohai-solo command not found" } self.assertRaises(errors.SystemInfoCommandMissing, - ohai_solo.system_info, mock_ssh) + ohai_solo.system_info, self.mock_remotesshclient) def test_system_info_could_not_find(self): - mock_ssh = mock.MagicMock() - mock_ssh.execute.return_value = { + self.mock_remotesshclient.execute.return_value = { 'exit_code': 1, 'stdout': "", 'stderr': "Could not find ohai-solo." } self.assertRaises(errors.SystemInfoCommandMissing, - ohai_solo.system_info, mock_ssh) + ohai_solo.system_info, self.mock_remotesshclient) if __name__ == "__main__": diff --git a/satori/tunnel.py b/satori/tunnel.py new file mode 100644 index 0000000..e40d583 --- /dev/null +++ b/satori/tunnel.py @@ -0,0 +1,167 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + + +"""SSH tunneling module. + +Set up a forward tunnel across an SSH server, using paramiko. A local port +(given with -p) is forwarded across an SSH session to an address:port from +the SSH server. This is similar to the openssh -L option. +""" +try: + import eventlet + eventlet.monkey_patch() + from eventlet.green import threading + from eventlet.green import time +except ImportError: + import threading + import time + pass + +import logging +import select +import socket +try: + import SocketServer +except ImportError: + import socketserver as SocketServer + +import paramiko + +LOG = logging.getLogger(__name__) + + +class TunnelServer(SocketServer.ThreadingTCPServer): + + """Serve on a local ephemeral port. + + Clients will connect to that port/server. + """ + + daemon_threads = True + allow_reuse_address = True + + +class TunnelHandler(SocketServer.BaseRequestHandler): + + """Handle forwarding of packets.""" + + def handle(self): + """Do all the work required to service a request. + + The request is available as self.request, the client address as + self.client_address, and the server instance as self.server, in + case it needs to access per-server information. + + This implementation will forward packets. + """ + try: + chan = self.ssh_transport.open_channel('direct-tcpip', + self.target_address, + self.request.getpeername()) + except Exception as exc: + LOG.error('Incoming request to %s:%s failed', + self.target_address[0], + self.target_address[1], + exc_info=exc) + return + if chan is None: + LOG.error('Incoming request to %s:%s was rejected ' + 'by the SSH server.', + self.target_address[0], + self.target_address[1]) + return + + while True: + r, w, x = select.select([self.request, chan], [], []) + if self.request in r: + data = self.request.recv(1024) + if len(data) == 0: + break + chan.send(data) + if chan in r: + data = chan.recv(1024) + if len(data) == 0: + break + self.request.send(data) + + try: + peername = None + peername = str(self.request.getpeername()) + except socket.error as exc: + LOG.warning("Couldn't fetch peername.", exc_info=exc) + chan.close() + self.request.close() + LOG.info("Tunnel closed from '%s'", peername or 'unnamed peer') + + +class Tunnel(object): # pylint: disable=R0902 + + """Create a TCP server which will use TunnelHandler.""" + + def __init__(self, target_host, target_port, + sshclient, tunnel_host='localhost', + tunnel_port=0): + """Constructor.""" + if not isinstance(sshclient, paramiko.SSHClient): + raise TypeError("'sshclient' must be an instance of " + "paramiko.SSHClient.") + + self.target_host = target_host + self.target_port = target_port + self.target_address = (target_host, target_port) + self.address = (tunnel_host, tunnel_port) + + self._tunnel = None + self._tunnel_thread = None + self.sshclient = sshclient + self._ssh_transport = self.get_sshclient_transport( + self.sshclient) + + TunnelHandler.target_address = self.target_address + TunnelHandler.ssh_transport = self._ssh_transport + + self._tunnel = TunnelServer(self.address, TunnelHandler) + # reset attribute to the port it has actually been set to + self.address = self._tunnel.server_address + tunnel_host, self.tunnel_port = self.address + + def get_sshclient_transport(self, sshclient): + """Get the sshclient's transport. + + Connect the sshclient, that has been passed in and return its + transport. + """ + sshclient.connect() + return sshclient.get_transport() + + def serve_forever(self, async=True): + """Serve the tunnel forever. + + if async is True, this will be done in a background thread + """ + if not async: + self._tunnel.serve_forever() + else: + self._tunnel_thread = threading.Thread( + target=self._tunnel.serve_forever) + self.start() + # cooperative yield + time.sleep(0) + + def shutdown(self): + """Stop serving the tunnel. + + Also close the socket. + """ + self._tunnel.shutdown() + self._tunnel.socket.close()