Update shell connection tools
- Allows to get a connection by its host name - 'localhost' and local hostname will return a local connection - add method to expand paths - add method to get local and remote environment variables Change-Id: I21490a9dc889e4be4256ea406bd1707b7bdd3ae5
This commit is contained in:
parent
a5e5f0386f
commit
d9513fe77b
@ -37,7 +37,7 @@ from tobiko.shell import ssh
|
||||
|
||||
LOG = log.getLogger(__name__)
|
||||
|
||||
ShellConnectionType = typing.Union['ShellConnection', ssh.SSHClientType]
|
||||
ShellConnectionType = typing.Union['ShellConnection', ssh.SSHClientType, str]
|
||||
|
||||
|
||||
def connection_hostname(connection: ShellConnectionType = None) -> str:
|
||||
@ -135,8 +135,7 @@ def shell_connection(obj: ShellConnectionType = None,
|
||||
if isinstance(obj, ShellConnection):
|
||||
return obj
|
||||
else:
|
||||
return shell_connection_manager(manager).get_shell_connection(
|
||||
ssh_client=obj)
|
||||
return shell_connection_manager(manager).get_shell_connection(obj)
|
||||
|
||||
|
||||
def ssh_shell_connection(obj: ShellConnectionType = None,
|
||||
@ -166,14 +165,34 @@ ShellConnectionKey = typing.Optional[ssh.SSHClientFixture]
|
||||
|
||||
class ShellConnectionManager(tobiko.SharedFixture):
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self,
|
||||
local_hostnames: typing.Iterable[str] = None):
|
||||
super(ShellConnectionManager, self).__init__()
|
||||
self._host_connections: typing.Dict['ShellConnectionKey',
|
||||
'ShellConnection'] = {}
|
||||
if local_hostnames is not None:
|
||||
local_hostnames = set(local_hostnames)
|
||||
self._local_hostnames = local_hostnames
|
||||
|
||||
def get_shell_connection(self, ssh_client: ssh.SSHClientType) -> \
|
||||
@property
|
||||
def local_hostnames(self) -> typing.Set[str]:
|
||||
if self._local_hostnames is None:
|
||||
hostname = socket.gethostname()
|
||||
self._local_hostnames = {'localhost',
|
||||
tobiko.get_short_hostname(hostname)}
|
||||
return self._local_hostnames
|
||||
|
||||
def get_shell_connection(self,
|
||||
obj: typing.Union[str, ssh.SSHClientType]) -> \
|
||||
'ShellConnection':
|
||||
ssh_client = ssh.ssh_client_fixture(ssh_client)
|
||||
ssh_client: typing.Optional[ssh.SSHClientFixture]
|
||||
if isinstance(obj, str):
|
||||
if obj in self.local_hostnames:
|
||||
ssh_client = None # local connection
|
||||
else:
|
||||
ssh_client = ssh.ssh_client(host=obj)
|
||||
else:
|
||||
ssh_client = ssh.ssh_client_fixture(obj)
|
||||
connection = self._host_connections.get(ssh_client)
|
||||
if connection is None:
|
||||
connection = self._get_shell_connection(ssh_client=ssh_client)
|
||||
@ -228,6 +247,9 @@ class ShellConnection(tobiko.SharedFixture):
|
||||
def login(self) -> str:
|
||||
return f"{self.username}@{self.hostname}"
|
||||
|
||||
def get_config_path(self, path: str) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def execute(self,
|
||||
command: _command.ShellCommandType,
|
||||
*args, **execute_params) -> \
|
||||
@ -272,6 +294,9 @@ class ShellConnection(tobiko.SharedFixture):
|
||||
def get_file(self, remote_file: str, local_file: str):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_environ(self) -> typing.Dict[str, str]:
|
||||
raise NotImplementedError
|
||||
|
||||
def open_file(self,
|
||||
filename: typing.Union[str, bytes],
|
||||
mode: str,
|
||||
@ -322,6 +347,9 @@ class LocalShellConnection(ShellConnection):
|
||||
def hostname(self) -> str:
|
||||
return socket.gethostname()
|
||||
|
||||
def get_environ(self) -> typing.Dict[str, str]:
|
||||
return dict(os.environ)
|
||||
|
||||
def put_file(self, local_file: str, remote_file: str):
|
||||
LOG.debug(f"Copy local file as {self.login}: '{local_file}' -> "
|
||||
f"'{remote_file}' ...")
|
||||
@ -386,6 +414,9 @@ class LocalShellConnection(ShellConnection):
|
||||
os.makedirs(name=name,
|
||||
exist_ok=exist_ok)
|
||||
|
||||
def get_config_path(self, path: str) -> str:
|
||||
return tobiko.tobiko_config_path(path)
|
||||
|
||||
|
||||
class SSHShellConnection(ShellConnection):
|
||||
|
||||
@ -430,6 +461,11 @@ class SSHShellConnection(ShellConnection):
|
||||
self._sftp = self.ssh_client.connect().open_sftp()
|
||||
return self._sftp
|
||||
|
||||
def get_environ(self) -> typing.Dict[str, str]:
|
||||
return dict(_parse_env_line(line)
|
||||
for line in self.execute('env').stdout().splitlines()
|
||||
if line.strip())
|
||||
|
||||
def put_file(self, local_file: str, remote_file: str):
|
||||
LOG.debug(f"Put remote file as {self.login}: '{local_file}' -> "
|
||||
f"'{remote_file}'...")
|
||||
@ -506,3 +542,21 @@ class SSHShellConnection(ShellConnection):
|
||||
command += '-p'
|
||||
command += name
|
||||
self.execute(command)
|
||||
|
||||
_user_dir: typing.Optional[str] = None
|
||||
|
||||
@property
|
||||
def user_dir(self) -> str:
|
||||
if self._user_dir is None:
|
||||
self._user_dir = self.execute('sh -c "echo $HOME"').stdout.strip()
|
||||
return self._user_dir
|
||||
|
||||
def get_config_path(self, path: str) -> str:
|
||||
if path[0] in '~.':
|
||||
path = f"{self.user_dir}{path[1:]}"
|
||||
return path
|
||||
|
||||
|
||||
def _parse_env_line(line: str) -> typing.Tuple[str, str]:
|
||||
name, value = line.split('=', 1)
|
||||
return name.strip(), value.strip()
|
||||
|
Loading…
x
Reference in New Issue
Block a user