diff --git a/changelog.md b/changelog.md index d8b3c4d6..bdf251ca 100644 --- a/changelog.md +++ b/changelog.md @@ -3,9 +3,10 @@ TBD Features: --------- -* Add an option `--ssh-config-host` to read ssh configuration from OpenSSH configuration file. -* Add an option `--list-ssh-config` to list ssh configurations. -* Add an option `--ssh-config-path` to choose ssh configuration path. +* Add an option `--ssh-config-host` to read ssh configuration from OpenSSH configuration file (Thanks: [Nathan Huang]). +* Add an option `--list-ssh-config` to list ssh configurations (Thanks: [Nathan Huang]). +* Add an option `--ssh-config-path` to choose ssh configuration path (Thanks: [Nathan Huang]). +* Reuse the same SSH connection in both main thread and completion thread (Thanks: [Georgy Frolov]). 1.21.1 @@ -757,3 +758,4 @@ Bug Fixes: [François Pietka]: https://github.com/fpietka [Frederic Aoustin]: https://github.com/fraoustin [Georgy Frolov]: https://github.com/pasenor +[Nathan Huang]: https://github.com/hxueh diff --git a/mycli/completion_refresher.py b/mycli/completion_refresher.py index e6c8dd07..29c8da5f 100644 --- a/mycli/completion_refresher.py +++ b/mycli/completion_refresher.py @@ -51,8 +51,7 @@ def _bg_refresh(self, sqlexecute, callbacks, completer_options): e = sqlexecute executor = SQLExecute(e.dbname, e.user, e.password, e.host, e.port, e.socket, e.charset, e.local_infile, e.ssl, - e.ssh_user, e.ssh_host, e.ssh_port, - e.ssh_password, e.ssh_key_filename) + e.ssh_client) # If callbacks is a single function then push it into a list. if callable(callbacks): diff --git a/mycli/main.py b/mycli/main.py index d298f202..22817ab7 100755 --- a/mycli/main.py +++ b/mycli/main.py @@ -31,10 +31,11 @@ from prompt_toolkit.history import FileHistory from prompt_toolkit.auto_suggest import AutoSuggestFromHistory +from mycli.packages.ssh_client import create_ssh_client from .packages.special.main import NO_QUERY from .packages.prompt_utils import confirm, confirm_destructive_query from .packages.tabular_output import sql_format -from .packages import special +from .packages import special, ssh_client from .packages.special.favoritequeries import FavoriteQueries from .sqlcompleter import SQLCompleter from .clitoolbar import create_toolbar_tokens_func @@ -63,11 +64,6 @@ from urllib.parse import unquote -try: - import paramiko -except ImportError: - from mycli.packages.paramiko_stub import paramiko - # Query tuples are used for maintaining history Query = namedtuple('Query', ['query', 'successful', 'mutating']) @@ -198,6 +194,8 @@ def __init__(self, sqlexecute=None, prompt=None, self.prompt_app = None + self.ssh_client = None + def register_special_commands(self): special.register_special_command(self.change_db, 'use', '\\u', 'Change to a new database.', aliases=('\\u',)) @@ -358,9 +356,7 @@ def merge_ssl_with_cnf(self, ssl, cnf): return merged def connect(self, database='', user='', passwd='', host='', port='', - socket='', charset='', local_infile='', ssl='', - ssh_user='', ssh_host='', ssh_port='', - ssh_password='', ssh_key_filename=''): + socket='', charset='', local_infile='', ssl=None): cnf = {'database': None, 'user': None, @@ -384,7 +380,7 @@ def connect(self, database='', user='', passwd='', host='', port='', database = database or cnf['database'] # Socket interface not supported for SSH connections - if port or host or ssh_host or ssh_port: + if port or host or self.ssh_client: socket = '' else: socket = socket or cnf['socket'] or guess_socket_location() @@ -416,8 +412,7 @@ def _connect(): try: self.sqlexecute = SQLExecute( database, user, passwd, host, port, socket, charset, - local_infile, ssl, ssh_user, ssh_host, ssh_port, - ssh_password, ssh_key_filename + local_infile, ssl, ssh_client=self.ssh_client ) except OperationalError as e: if ('Access denied for user' in e.args[1]): @@ -425,8 +420,7 @@ def _connect(): show_default=False, type=str, err=True) self.sqlexecute = SQLExecute( database, user, new_passwd, host, port, socket, - charset, local_infile, ssl, ssh_user, ssh_host, - ssh_port, ssh_password, ssh_key_filename + charset, local_infile, ssl, ssh_client=self.ssh_client ) else: raise e @@ -1092,16 +1086,17 @@ def cli(database, user, host, port, socket, password, dbname, else: click.secho(alias) sys.exit(0) + if list_ssh_config: - ssh_config = read_ssh_config(ssh_config_path) - for host in ssh_config.get_hostnames(): + hosts = ssh_client.get_config_hosts(ssh_config_path) + for host, hostname in hosts.items(): if verbose: - host_config = ssh_config.lookup(host) click.secho("{} : {}".format( - host, host_config.get('hostname'))) + host, hostname)) else: click.secho(host) sys.exit(0) + # Choose which ever one has a valid value. database = dbname or database @@ -1153,7 +1148,7 @@ def cli(database, user, host, port, socket, password, dbname, port = uri.port if ssh_config_host: - ssh_config = read_ssh_config( + ssh_config = ssh_client.read_config_file( ssh_config_path ).lookup(ssh_config_host) ssh_host = ssh_host if ssh_host else ssh_config.get('hostname') @@ -1164,7 +1159,10 @@ def cli(database, user, host, port, socket, password, dbname, ssh_key_filename = ssh_key_filename if ssh_key_filename else ssh_config.get( 'identityfile', [None])[0] - ssh_key_filename = ssh_key_filename and os.path.expanduser(ssh_key_filename) + if ssh_host: + mycli.ssh_client = create_ssh_client( + ssh_host, ssh_port, ssh_user, ssh_password, ssh_key_filename + ) mycli.connect( database=database, @@ -1175,11 +1173,6 @@ def cli(database, user, host, port, socket, password, dbname, socket=socket, local_infile=local_infile, ssl=ssl, - ssh_user=ssh_user, - ssh_host=ssh_host, - ssh_port=ssh_port, - ssh_password=ssh_password, - ssh_key_filename=ssh_key_filename ) mycli.logger.debug('Launch Params: \n' @@ -1298,26 +1291,5 @@ def edit_and_execute(event): buff.open_in_editor(validate_and_handle=False) -def read_ssh_config(ssh_config_path): - ssh_config = paramiko.config.SSHConfig() - try: - with open(ssh_config_path) as f: - ssh_config.parse(f) - # Paramiko prior to version 2.7 raises Exception on parse errors. - # In 2.7 it has become paramiko.ssh_exception.SSHException, - # but let's catch everything for compatibility - except Exception as err: - click.secho( - f'Could not parse SSH configuration file {ssh_config_path}:\n{err} ', - err=True, fg='red' - ) - sys.exit(1) - except FileNotFoundError as e: - click.secho(str(e), err=True, fg='red') - sys.exit(1) - else: - return ssh_config - - if __name__ == "__main__": cli() diff --git a/mycli/packages/ssh_client/__init__.py b/mycli/packages/ssh_client/__init__.py new file mode 100644 index 00000000..216b8db1 --- /dev/null +++ b/mycli/packages/ssh_client/__init__.py @@ -0,0 +1 @@ +from .client import get_config_hosts, create_ssh_client, SSHException, read_config_file diff --git a/mycli/packages/ssh_client/client.py b/mycli/packages/ssh_client/client.py new file mode 100644 index 00000000..3abfd018 --- /dev/null +++ b/mycli/packages/ssh_client/client.py @@ -0,0 +1,47 @@ +""" +A very thin wrapper around paramiko, mostly to keep all SSH-related functionality in one place +""" +from io import open + +try: + import paramiko +except ImportError: + from mycli.packages.paramiko_stub import paramiko + + +class SSHException(Exception): + pass + + +def get_config_hosts(config_path): + config = read_config_file(config_path) + return { + host: config.lookup(host).get("hostname") for host in config.get_hostnames() + } + + +def create_ssh_client(ssh_host, ssh_port, ssh_user, ssh_password=None, ssh_key_filename=None) -> paramiko.SSHClient: + client = paramiko.SSHClient() + client.load_system_host_keys() + client.set_missing_host_key_policy(paramiko.WarningPolicy()) + client.connect( + ssh_host, ssh_port, ssh_user, password=ssh_password, key_filename=ssh_key_filename + ) + return client + + +def read_config_file(config_path) -> paramiko.SSHConfig: + ssh_config = paramiko.config.SSHConfig() + try: + with open(config_path) as f: + ssh_config.parse(f) + # Paramiko prior to version 2.7 raises Exception on parse errors. + # In 2.7 it has become paramiko.ssh_exception.SSHException, + # but let's catch everything for compatibility + except Exception as err: + raise SSHException( + f"Could not parse SSH configuration file {config_path}:\n{err} ", + ) + except FileNotFoundError as e: + raise SSHException(str(e)) + return ssh_config diff --git a/mycli/sqlexecute.py b/mycli/sqlexecute.py index 035d98d1..fe229695 100644 --- a/mycli/sqlexecute.py +++ b/mycli/sqlexecute.py @@ -6,10 +6,7 @@ from pymysql.converters import (convert_mysql_timestamp, convert_datetime, convert_timedelta, convert_date, conversions, decoders) -try: - import paramiko -except ImportError: - from mycli.packages.paramiko_stub import paramiko + _logger = logging.getLogger(__name__) @@ -18,6 +15,7 @@ FIELD_TYPE.NULL: type(None) }) + class SQLExecute(object): databases_query = '''SHOW DATABASES''' @@ -41,8 +39,8 @@ class SQLExecute(object): order by table_name,ordinal_position''' def __init__(self, database, user, password, host, port, socket, charset, - local_infile, ssl, ssh_user, ssh_host, ssh_port, ssh_password, - ssh_key_filename): + local_infile, ssl, + ssh_client=None): self.dbname = database self.user = user self.password = password @@ -54,17 +52,12 @@ def __init__(self, database, user, password, host, port, socket, charset, self.ssl = ssl self._server_type = None self.connection_id = None - self.ssh_user = ssh_user - self.ssh_host = ssh_host - self.ssh_port = ssh_port - self.ssh_password = ssh_password - self.ssh_key_filename = ssh_key_filename + self.ssh_client = ssh_client + self.connect() def connect(self, database=None, user=None, password=None, host=None, - port=None, socket=None, charset=None, local_infile=None, - ssl=None, ssh_host=None, ssh_port=None, ssh_user=None, - ssh_password=None, ssh_key_filename=None): + port=None, socket=None, charset=None, local_infile=None, ssl=None): db = (database or self.dbname) user = (user or self.user) password = (password or self.password) @@ -74,11 +67,6 @@ def connect(self, database=None, user=None, password=None, host=None, charset = (charset or self.charset) local_infile = (local_infile or self.local_infile) ssl = (ssl or self.ssl) - ssh_user = (ssh_user or self.ssh_user) - ssh_host = (ssh_host or self.ssh_host) - ssh_port = (ssh_port or self.ssh_port) - ssh_password = (ssh_password or self.ssh_password) - ssh_key_filename = (ssh_key_filename or self.ssh_key_filename) _logger.debug( 'Connection DB Params: \n' '\tdatabase: %r' @@ -88,14 +76,8 @@ def connect(self, database=None, user=None, password=None, host=None, '\tsocket: %r' '\tcharset: %r' '\tlocal_infile: %r' - '\tssl: %r' - '\tssh_user: %r' - '\tssh_host: %r' - '\tssh_port: %r' - '\tssh_password: %r' - '\tssh_key_filename: %r', + '\tssl: %r', db, user, host, port, socket, charset, local_infile, ssl, - ssh_user, ssh_host, ssh_port, ssh_password, ssh_key_filename ) conv = conversions.copy() conv.update({ @@ -107,26 +89,16 @@ def connect(self, database=None, user=None, password=None, host=None, defer_connect = False - if ssh_host: - defer_connect = True - conn = pymysql.connect( database=db, user=user, password=password, host=host, port=port, unix_socket=socket, use_unicode=True, charset=charset, autocommit=True, client_flag=pymysql.constants.CLIENT.INTERACTIVE, local_infile=local_infile, conv=conv, ssl=ssl, program_name="mycli", - defer_connect=defer_connect + defer_connect=self.ssh_client is not None ) - if ssh_host: - client = paramiko.SSHClient() - client.load_system_host_keys() - client.set_missing_host_key_policy(paramiko.WarningPolicy()) - client.connect( - ssh_host, ssh_port, ssh_user, ssh_password, - key_filename=ssh_key_filename - ) - chan = client.get_transport().open_channel( + if self.ssh_client: + chan = self.ssh_client.get_transport().open_channel( 'direct-tcpip', (host, port), ('0.0.0.0', 0), diff --git a/test/conftest.py b/test/conftest.py index cf6d721b..d6361a01 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,4 +1,6 @@ import pytest + +from mycli.packages.ssh_client import create_ssh_client from .utils import (HOST, USER, PASSWORD, PORT, CHARSET, create_db, db_connection, SSH_USER, SSH_HOST, SSH_PORT) import mycli.sqlexecute @@ -21,9 +23,13 @@ def cursor(connection): @pytest.fixture def executor(connection): + if SSH_HOST: + ssh_client = create_ssh_client(SSH_HOST, SSH_PORT, SSH_USER) + else: + ssh_client = None + return mycli.sqlexecute.SQLExecute( database='_test_db', user=USER, host=HOST, password=PASSWORD, port=PORT, socket=None, charset=CHARSET, - local_infile=False, ssl=None, ssh_user=SSH_USER, ssh_host=SSH_HOST, - ssh_port=SSH_PORT, ssh_password=None, ssh_key_filename=None + local_infile=False, ssl=None, ssh_client=ssh_client ) diff --git a/test/test_main.py b/test/test_main.py index 3f92bd1b..8ecefaa3 100644 --- a/test/test_main.py +++ b/test/test_main.py @@ -434,6 +434,7 @@ def __init__(self, **args): self.logger = Logger() self.destructive_warning = False self.formatter = Formatter() + self._ssh_client = None def connect(self, **args): MockMyCli.connect_args = args @@ -441,8 +442,23 @@ def connect(self, **args): def run_query(self, query, new_line=True): pass + @property + def ssh_client(self): + pass + + @ssh_client.setter + def ssh_client(self, client): + MockMyCli._ssh_client = client + + def mock_create_ssh_client(*args): + return namedtuple( + 'SSHConf', + ('host', 'port', 'user', 'password', 'key_filename') + )(*args) + import mycli.main monkeypatch.setattr(mycli.main, 'MyCli', MockMyCli) + monkeypatch.setattr(mycli.main, 'create_ssh_client', mock_create_ssh_client) runner = CliRunner() # Setup temporary configuration @@ -466,10 +482,10 @@ def run_query(self, query, new_line=True): assert result.exit_code == 0, result.output + \ " " + str(result.exception) assert \ - MockMyCli.connect_args["ssh_user"] == "joe" and \ - MockMyCli.connect_args["ssh_host"] == "test.example.com" and \ - MockMyCli.connect_args["ssh_port"] == 22222 and \ - MockMyCli.connect_args["ssh_key_filename"] == os.getenv( + MockMyCli._ssh_client.user == "joe" and \ + MockMyCli._ssh_client.host == "test.example.com" and \ + MockMyCli._ssh_client.port == 22222 and \ + MockMyCli._ssh_client.key_filename == os.getenv( "HOME") + "/.ssh/gateway" # When a user supplies a ssh config host as argument to mycli, @@ -488,7 +504,7 @@ def run_query(self, query, new_line=True): assert result.exit_code == 0, result.output + \ " " + str(result.exception) assert \ - MockMyCli.connect_args["ssh_user"] == "arg_user" and \ - MockMyCli.connect_args["ssh_host"] == "arg_host" and \ - MockMyCli.connect_args["ssh_port"] == 3 and \ - MockMyCli.connect_args["ssh_key_filename"] == "/path/to/key" + MockMyCli._ssh_client.user == "arg_user" and \ + MockMyCli._ssh_client.host == "arg_host" and \ + MockMyCli._ssh_client.port == 3 and \ + MockMyCli._ssh_client.key_filename == "/path/to/key"