Update shell connection tools

- Allows to get a connection by its host name
- 'localhost' and local hostname will return a local connection
- add method to expand paths
- add method to get local and remote environment variables

Change-Id: I21490a9dc889e4be4256ea406bd1707b7bdd3ae5
This commit is contained in:
Federico Ressi 2022-07-19 09:14:02 +02:00
parent a5e5f0386f
commit d9513fe77b

View File

@ -37,7 +37,7 @@ from tobiko.shell import ssh
LOG = log.getLogger(__name__)
ShellConnectionType = typing.Union['ShellConnection', ssh.SSHClientType]
ShellConnectionType = typing.Union['ShellConnection', ssh.SSHClientType, str]
def connection_hostname(connection: ShellConnectionType = None) -> str:
@ -135,8 +135,7 @@ def shell_connection(obj: ShellConnectionType = None,
if isinstance(obj, ShellConnection):
return obj
else:
return shell_connection_manager(manager).get_shell_connection(
ssh_client=obj)
return shell_connection_manager(manager).get_shell_connection(obj)
def ssh_shell_connection(obj: ShellConnectionType = None,
@ -166,14 +165,34 @@ ShellConnectionKey = typing.Optional[ssh.SSHClientFixture]
class ShellConnectionManager(tobiko.SharedFixture):
def __init__(self):
def __init__(self,
local_hostnames: typing.Iterable[str] = None):
super(ShellConnectionManager, self).__init__()
self._host_connections: typing.Dict['ShellConnectionKey',
'ShellConnection'] = {}
if local_hostnames is not None:
local_hostnames = set(local_hostnames)
self._local_hostnames = local_hostnames
def get_shell_connection(self, ssh_client: ssh.SSHClientType) -> \
@property
def local_hostnames(self) -> typing.Set[str]:
if self._local_hostnames is None:
hostname = socket.gethostname()
self._local_hostnames = {'localhost',
tobiko.get_short_hostname(hostname)}
return self._local_hostnames
def get_shell_connection(self,
obj: typing.Union[str, ssh.SSHClientType]) -> \
'ShellConnection':
ssh_client = ssh.ssh_client_fixture(ssh_client)
ssh_client: typing.Optional[ssh.SSHClientFixture]
if isinstance(obj, str):
if obj in self.local_hostnames:
ssh_client = None # local connection
else:
ssh_client = ssh.ssh_client(host=obj)
else:
ssh_client = ssh.ssh_client_fixture(obj)
connection = self._host_connections.get(ssh_client)
if connection is None:
connection = self._get_shell_connection(ssh_client=ssh_client)
@ -228,6 +247,9 @@ class ShellConnection(tobiko.SharedFixture):
def login(self) -> str:
return f"{self.username}@{self.hostname}"
def get_config_path(self, path: str) -> str:
raise NotImplementedError
def execute(self,
command: _command.ShellCommandType,
*args, **execute_params) -> \
@ -272,6 +294,9 @@ class ShellConnection(tobiko.SharedFixture):
def get_file(self, remote_file: str, local_file: str):
raise NotImplementedError
def get_environ(self) -> typing.Dict[str, str]:
raise NotImplementedError
def open_file(self,
filename: typing.Union[str, bytes],
mode: str,
@ -322,6 +347,9 @@ class LocalShellConnection(ShellConnection):
def hostname(self) -> str:
return socket.gethostname()
def get_environ(self) -> typing.Dict[str, str]:
return dict(os.environ)
def put_file(self, local_file: str, remote_file: str):
LOG.debug(f"Copy local file as {self.login}: '{local_file}' -> "
f"'{remote_file}' ...")
@ -386,6 +414,9 @@ class LocalShellConnection(ShellConnection):
os.makedirs(name=name,
exist_ok=exist_ok)
def get_config_path(self, path: str) -> str:
return tobiko.tobiko_config_path(path)
class SSHShellConnection(ShellConnection):
@ -430,6 +461,11 @@ class SSHShellConnection(ShellConnection):
self._sftp = self.ssh_client.connect().open_sftp()
return self._sftp
def get_environ(self) -> typing.Dict[str, str]:
return dict(_parse_env_line(line)
for line in self.execute('env').stdout().splitlines()
if line.strip())
def put_file(self, local_file: str, remote_file: str):
LOG.debug(f"Put remote file as {self.login}: '{local_file}' -> "
f"'{remote_file}'...")
@ -506,3 +542,21 @@ class SSHShellConnection(ShellConnection):
command += '-p'
command += name
self.execute(command)
_user_dir: typing.Optional[str] = None
@property
def user_dir(self) -> str:
if self._user_dir is None:
self._user_dir = self.execute('sh -c "echo $HOME"').stdout.strip()
return self._user_dir
def get_config_path(self, path: str) -> str:
if path[0] in '~.':
path = f"{self.user_dir}{path[1:]}"
return path
def _parse_env_line(line: str) -> typing.Tuple[str, str]:
name, value = line.split('=', 1)
return name.strip(), value.strip()