
New implementation adds verification of created port forwarding. Also it produce detailed error reporting including verbose output from ssh process. Change-Id: I8841e89cf7e78fad0f99e05d1d83a73cc253d05b Closes-Bug: #1644835
304 lines
9.7 KiB
Python
304 lines
9.7 KiB
Python
#
|
|
# Copyright 2016 Cray Inc., All Rights Reserved
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
|
# not use this file except in compliance with the License. You may obtain
|
|
# a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
# License for the specific language governing permissions and limitations
|
|
# under the License.
|
|
|
|
import abc
|
|
import collections
|
|
import fcntl
|
|
import os
|
|
import select
|
|
import socket
|
|
import subprocess
|
|
import tempfile
|
|
import textwrap
|
|
import time
|
|
|
|
import six
|
|
from ironic.common import exception
|
|
|
|
from bareon_ironic import exception as exc
|
|
from bareon_ironic.common import subproc_utils
|
|
|
|
NetAddr = collections.namedtuple('NetAddr', ('host', 'port'))
|
|
|
|
|
|
@six.add_metaclass(abc.ABCMeta)
|
|
class _SSHPortForwardingAbstract(object):
|
|
"""Abstract base class for SSH port forwarding handlers
|
|
|
|
Local and remote port forwarding management handlers are based on this
|
|
abstract class. It implement context management protocol so it must be used
|
|
inside "with" keyword. Port forwarding is activated into __enter__ method
|
|
and deactivated into __exit__ method.
|
|
"""
|
|
|
|
proc = None
|
|
subprocess_output = None
|
|
|
|
def __init__(self, user, key_file, host, bind, forward,
|
|
ssh_port=None, validate_timeout=30):
|
|
"""Collect and store data required for port forwarding setup
|
|
|
|
:param user: username for SSH connections
|
|
:type user: str
|
|
:param key_file: path to private key for SSH authentication
|
|
:type key_file: str
|
|
:param host: address of remote hosts
|
|
:type host: str
|
|
:param bind: address/port pair to bind forwarded port
|
|
:type bind: NetAddr
|
|
:param forward: address/port pair defining destination of port
|
|
forwarding
|
|
:type forward: NetAddr
|
|
:param ssh_port: SSH port on remote host
|
|
:type ssh_port: int
|
|
:param validate_timeout: amount of time we will wait for creation of
|
|
port forwarding. If set to zero - skip
|
|
validation.
|
|
:type validate_timeout: int
|
|
"""
|
|
self.user = user
|
|
self.key = key_file
|
|
self.host = host
|
|
self.bind = bind
|
|
self.forward = forward
|
|
self.ssh_port = ssh_port
|
|
self.validate_timeout = max(validate_timeout, 0)
|
|
|
|
def __repr__(self):
|
|
setup = self._ssh_command_add_forward_arguments([])
|
|
setup = ' '.join(setup)
|
|
return '<{} {}>'.format(type(self).__name__, setup)
|
|
|
|
def __enter__(self):
|
|
self.subprocess_output = tempfile.TemporaryFile()
|
|
|
|
cmd = self._make_ssh_command()
|
|
proc_args = {
|
|
'close_fds': True,
|
|
'stderr': self.subprocess_output}
|
|
proc_args = self._subprocess_add_args(proc_args)
|
|
try:
|
|
self.proc = subprocess.Popen(cmd, **proc_args)
|
|
if self.validate_timeout:
|
|
self._check_port_forwarding()
|
|
except OSError as e:
|
|
raise exc.SubprocessError(command=cmd, error=e)
|
|
except exc.SSHForwardedPortValidationError as e:
|
|
self._kill()
|
|
remote = self._ssh_command_add_destination([])
|
|
remote = ' '.join(str(x) for x in remote)
|
|
raise exc.SSHSetupForwardingError(
|
|
forward=self, error=e.message, remote=remote,
|
|
output=self._grab_subprocess_output())
|
|
except Exception:
|
|
self._kill()
|
|
raise
|
|
|
|
return self
|
|
|
|
def __exit__(self, *exc_info):
|
|
del exc_info
|
|
if self.proc.poll() is not None:
|
|
raise exception.SSHConnectFailed(
|
|
'Unexpected SSH process termination.\n'
|
|
'{}'.format(self._grab_subprocess_output()))
|
|
self._kill()
|
|
|
|
def _make_ssh_command(self):
|
|
cmd = ['ssh']
|
|
cmd = self._ssh_command_add_auth_arguments(cmd)
|
|
cmd = self._ssh_command_add_forward_arguments(cmd)
|
|
cmd = self._ssh_command_add_extra_arguments(cmd)
|
|
cmd = self._ssh_command_add_destination(cmd)
|
|
cmd = self._ssh_command_add_command(cmd)
|
|
|
|
return cmd
|
|
|
|
def _ssh_command_add_auth_arguments(self, cmd):
|
|
return cmd + ['-i', self.key]
|
|
|
|
@abc.abstractmethod
|
|
def _ssh_command_add_forward_arguments(self, cmd):
|
|
return cmd
|
|
|
|
def _ssh_command_add_extra_arguments(self, cmd):
|
|
return cmd + [
|
|
'-v',
|
|
'-o', 'BatchMode=yes',
|
|
'-o', 'RequestTTY=no',
|
|
'-o', 'ExitOnForwardFailure=yes',
|
|
'-o', 'StrictHostKeyChecking=no',
|
|
'-o', 'UserKnownHostsFile=/dev/null']
|
|
|
|
def _ssh_command_add_destination(self, cmd):
|
|
if self.ssh_port is not None:
|
|
cmd += ['-p', str(self.ssh_port)]
|
|
return cmd + ['@'.join((self.user, self.host))]
|
|
|
|
def _ssh_command_add_command(self, cmd):
|
|
return cmd
|
|
|
|
def _subprocess_add_args(self, args):
|
|
return args
|
|
|
|
@abc.abstractmethod
|
|
def _check_port_forwarding(self):
|
|
pass
|
|
|
|
def _kill(self):
|
|
terminator = subproc_utils.ProcTerminator(self.proc)
|
|
terminator()
|
|
|
|
def _grab_subprocess_output(self):
|
|
output = 'There is no output from SSH process'
|
|
if self.subprocess_output.tell():
|
|
self.subprocess_output.seek(os.SEEK_SET)
|
|
output = self.subprocess_output.read()
|
|
return output
|
|
|
|
|
|
class SSHLocalPortForwarding(_SSHPortForwardingAbstract):
|
|
def _ssh_command_add_forward_arguments(self, cmd):
|
|
forward = self.bind + self.forward
|
|
forward = [str(x) for x in forward]
|
|
return cmd + ['-L', ':'.join(forward)]
|
|
|
|
def _ssh_command_add_extra_arguments(self, cmd):
|
|
parent = super(SSHLocalPortForwarding, self)
|
|
return parent._ssh_command_add_extra_arguments(cmd) + ['-N']
|
|
|
|
def _check_port_forwarding(self):
|
|
now = time.time()
|
|
time_end = now + self.validate_timeout
|
|
while now < time_end and self.proc.poll() is None:
|
|
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
|
try:
|
|
s.connect(self.bind)
|
|
except socket.error:
|
|
time.sleep(1)
|
|
now = time.time()
|
|
continue
|
|
finally:
|
|
s.close()
|
|
|
|
break
|
|
else:
|
|
raise exc.SSHForwardedPortValidationError
|
|
|
|
|
|
class SSHRemotePortForwarding(_SSHPortForwardingAbstract):
|
|
def _ssh_command_add_forward_arguments(self, cmd):
|
|
forward = self.bind + self.forward
|
|
forward = [str(x) for x in forward]
|
|
return cmd + ['-R', ':'.join(forward)]
|
|
|
|
def _ssh_command_add_extra_arguments(self, cmd):
|
|
parent = super(SSHRemotePortForwarding, self)
|
|
cmd = parent._ssh_command_add_extra_arguments(cmd)
|
|
if not self.validate_timeout:
|
|
cmd += ['-N']
|
|
return cmd
|
|
|
|
def _ssh_command_add_command(self, cmd):
|
|
parent = super(SSHRemotePortForwarding, self)
|
|
cmd = parent._ssh_command_add_command(cmd)
|
|
if self.validate_timeout:
|
|
cmd += ['python']
|
|
return cmd
|
|
|
|
def _subprocess_add_args(self, args):
|
|
args = super(SSHRemotePortForwarding, self)._subprocess_add_args(args)
|
|
if self.validate_timeout:
|
|
args['stdin'] = subprocess.PIPE
|
|
args['stdout'] = subprocess.PIPE
|
|
return args
|
|
|
|
def _check_port_forwarding(self):
|
|
fail_marker = 'CONNECT FAILED'
|
|
success_marker = 'CONNECT APPROVED'
|
|
validate_snippet = textwrap.dedent("""
|
|
import os
|
|
import socket
|
|
import sys
|
|
import time
|
|
|
|
timeout = max({timeout}, 1)
|
|
addr = '{address.host}'
|
|
port = {address.port}
|
|
|
|
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
|
try:
|
|
now = time.time()
|
|
time_end = now + timeout
|
|
|
|
while now < time_end:
|
|
try:
|
|
s.connect((addr, port))
|
|
except socket.error:
|
|
time.sleep(1)
|
|
now = time.time()
|
|
continue
|
|
break
|
|
else:
|
|
sys.stderr.write('{fail_marker}')
|
|
sys.stderr.write(os.linesep)
|
|
sys.exit(1)
|
|
|
|
sys.stdout.write('{success_marker}')
|
|
sys.stdout.write(os.linesep)
|
|
sys.stdout.close()
|
|
finally:
|
|
s.close()
|
|
|
|
while True:
|
|
time.sleep(128)
|
|
""").lstrip().format(
|
|
address=self.bind, timeout=self.validate_timeout,
|
|
success_marker=success_marker, fail_marker=fail_marker)
|
|
|
|
self.proc.stdin.write(validate_snippet)
|
|
self.proc.stdin.close()
|
|
|
|
stdout = self.proc.stdout.fileno()
|
|
opts = fcntl.fcntl(stdout, fcntl.F_GETFL)
|
|
fcntl.fcntl(stdout, fcntl.F_SETFL, opts | os.O_NONBLOCK)
|
|
|
|
now = time.time()
|
|
time_end = now + self.validate_timeout
|
|
output = []
|
|
while now < time_end:
|
|
ready = select.select([stdout], [], [], time_end - now)
|
|
ready_read = ready[0]
|
|
if ready_read:
|
|
chunk = self.proc.stdout.read(512)
|
|
if not chunk:
|
|
break
|
|
output.append(chunk)
|
|
|
|
if success_marker in ''.join(output):
|
|
break
|
|
|
|
now = time.time()
|
|
else:
|
|
raise exc.SSHForwardedPortValidationError
|
|
|
|
|
|
def forward_remote_port(
|
|
port, user, key_file, target_host, ssh_port=None):
|
|
return SSHRemotePortForwarding(
|
|
user, key_file, target_host,
|
|
NetAddr('127.0.0.1', port), NetAddr('127.0.0.1', port),
|
|
ssh_port=ssh_port)
|