diff --git a/changelog.md b/changelog.md index 965ef97b..21f9464b 100644 --- a/changelog.md +++ b/changelog.md @@ -13,6 +13,7 @@ Features: * Add `-g` shortcut to option `--login-path`. * Alt-Enter dispatches the command in multi-line mode. * Allow to pass a file or FIFO path with --password-file when password is not specified or is failing (as suggested in this best-practice https://www.netmeister.org/blog/passing-passwords.html) +* Reuse the same SSH connection for both main thread and completion thread. Internal: --------- @@ -846,6 +847,7 @@ Bug Fixes: [Georgy Frolov]: https://github.com/pasenor [Zach DeCook]: https://zachdecook.com [laixintao]: https://github.com/laixintao +[Nathan Huang]: https://github.com/hxueh [mtorromeo]: https://github.com/mtorromeo [mwcm]: https://github.com/mwcm [xeron]: https://github.com/xeron diff --git a/mycli/completion_refresher.py b/mycli/completion_refresher.py index e6c8dd07..a60068d9 100644 --- a/mycli/completion_refresher.py +++ b/mycli/completion_refresher.py @@ -51,8 +51,7 @@ def _bg_refresh(self, sqlexecute, callbacks, completer_options): e = sqlexecute executor = SQLExecute(e.dbname, e.user, e.password, e.host, e.port, e.socket, e.charset, e.local_infile, e.ssl, - e.ssh_user, e.ssh_host, e.ssh_port, - e.ssh_password, e.ssh_key_filename) + ssh_client=e.ssh_client) # If callbacks is a single function then push it into a list. if callable(callbacks): diff --git a/mycli/main.py b/mycli/main.py index 3f08e9c3..eefd1cf8 100755 --- a/mycli/main.py +++ b/mycli/main.py @@ -37,10 +37,11 @@ from prompt_toolkit.history import FileHistory from prompt_toolkit.auto_suggest import AutoSuggestFromHistory +from mycli.packages.ssh_client import create_ssh_client from .packages.special.main import NO_QUERY from .packages.prompt_utils import confirm, confirm_destructive_query from .packages.tabular_output import sql_format -from .packages import special +from .packages import special, ssh_client from .packages.special.favoritequeries import FavoriteQueries from .sqlcompleter import SQLCompleter from .clitoolbar import create_toolbar_tokens_func @@ -74,11 +75,6 @@ # Python < 3.7 import importlib_resources as resources -try: - import paramiko -except ImportError: - from mycli.packages.paramiko_stub import paramiko - # Query tuples are used for maintaining history Query = namedtuple('Query', ['query', 'successful', 'mutating']) @@ -211,6 +207,8 @@ def __init__(self, sqlexecute=None, prompt=None, self.prompt_app = None + self.ssh_client = None + def register_special_commands(self): special.register_special_command(self.change_db, 'use', '\\u', 'Change to a new database.', aliases=('\\u',)) @@ -387,9 +385,8 @@ def merge_ssl_with_cnf(self, ssl, cnf): return merged def connect(self, database='', user='', passwd='', host='', port='', - socket='', charset='', local_infile='', ssl='', - ssh_user='', ssh_host='', ssh_port='', - ssh_password='', ssh_key_filename='', init_command='', password_file=''): + socket='', charset='', local_infile='', ssl=None, init_command='', + password_file=''): cnf = {'database': None, 'user': None, @@ -418,7 +415,7 @@ def connect(self, database='', user='', passwd='', host='', port='', ssl = ssl or {} port = port and int(port) - if not port: + if not port and not self.ssh_client: port = 3306 if not host or host == 'localhost': socket = ( @@ -455,8 +452,7 @@ def _connect(): try: self.sqlexecute = SQLExecute( database, user, passwd, host, port, socket, charset, - local_infile, ssl, ssh_user, ssh_host, ssh_port, - ssh_password, ssh_key_filename, init_command + local_infile, ssl, init_command, ssh_client=self.ssh_client ) except OperationalError as e: if e.args[0] == ERROR_CODE_ACCESS_DENIED: @@ -467,8 +463,8 @@ def _connect(): show_default=False, type=str, err=True) self.sqlexecute = SQLExecute( database, user, new_passwd, host, port, socket, - charset, local_infile, ssl, ssh_user, ssh_host, - ssh_port, ssh_password, ssh_key_filename, init_command + charset, local_infile, ssl, init_command, + ssh_client=self.ssh_client ) else: raise e @@ -1179,16 +1175,22 @@ def cli(database, user, host, port, socket, password, dbname, else: click.secho(alias) sys.exit(0) + if list_ssh_config: - ssh_config = read_ssh_config(ssh_config_path) - for host in ssh_config.get_hostnames(): + try: + hosts = ssh_client.get_config_hosts(ssh_config_path) + except ssh_client.SSHException as e: + click.secho(str(e), err=True, fg='red') + sys.exit(1) + + for host, hostname in hosts.items(): if verbose: - host_config = ssh_config.lookup(host) click.secho("{} : {}".format( - host, host_config.get('hostname'))) + host, hostname)) else: click.secho(host) sys.exit(0) + # Choose which ever one has a valid value. database = dbname or database @@ -1240,9 +1242,14 @@ def cli(database, user, host, port, socket, password, dbname, port = uri.port if ssh_config_host: - ssh_config = read_ssh_config( - ssh_config_path - ).lookup(ssh_config_host) + try: + ssh_config = ssh_client.read_config_file( + ssh_config_path + ).lookup(ssh_config_host) + except ssh_client.SSHException as e: + click.secho(str(e), err=True, fg='red') + sys.exit(1) + ssh_host = ssh_host if ssh_host else ssh_config.get('hostname') ssh_user = ssh_user if ssh_user else ssh_config.get('user') if ssh_config.get('port') and ssh_port == 22: @@ -1251,7 +1258,10 @@ def cli(database, user, host, port, socket, password, dbname, ssh_key_filename = ssh_key_filename if ssh_key_filename else ssh_config.get( 'identityfile', [None])[0] - ssh_key_filename = ssh_key_filename and os.path.expanduser(ssh_key_filename) + if ssh_host: + mycli.ssh_client = create_ssh_client( + ssh_host, ssh_port, ssh_user, ssh_password, ssh_key_filename + ) mycli.connect( database=database, @@ -1262,14 +1272,9 @@ def cli(database, user, host, port, socket, password, dbname, socket=socket, local_infile=local_infile, ssl=ssl, - ssh_user=ssh_user, - ssh_host=ssh_host, - ssh_port=ssh_port, - ssh_password=ssh_password, - ssh_key_filename=ssh_key_filename, init_command=init_command, charset=charset, - password_file=password_file + password_file=password_file, ) mycli.logger.debug('Launch Params: \n' @@ -1403,6 +1408,7 @@ def read_ssh_config(ssh_config_path): except FileNotFoundError as e: click.secho(str(e), err=True, fg='red') sys.exit(1) + # Paramiko prior to version 2.7 raises Exception on parse errors. # In 2.7 it has become paramiko.ssh_exception.SSHException, # but let's catch everything for compatibility diff --git a/mycli/packages/ssh_client/__init__.py b/mycli/packages/ssh_client/__init__.py new file mode 100644 index 00000000..216b8db1 --- /dev/null +++ b/mycli/packages/ssh_client/__init__.py @@ -0,0 +1 @@ +from .client import get_config_hosts, create_ssh_client, SSHException, read_config_file diff --git a/mycli/packages/ssh_client/client.py b/mycli/packages/ssh_client/client.py new file mode 100644 index 00000000..4403f297 --- /dev/null +++ b/mycli/packages/ssh_client/client.py @@ -0,0 +1,57 @@ +"""A very thin wrapper around paramiko, mostly to keep all SSH-related +functionality in one place.""" +from io import open +import logging + +_logger = logging.getLogger(__name__) + +try: + import paramiko +except ImportError: + from mycli.packages.paramiko_stub import paramiko + + +class SSHException(Exception): + pass + + +def get_config_hosts(config_path): + config = read_config_file(config_path) + return { + host: config.lookup(host).get("hostname") for host in config.get_hostnames() + } + + +def create_ssh_client(ssh_host, ssh_port, ssh_user, ssh_password=None, ssh_key_filename=None) -> paramiko.SSHClient: + client = paramiko.SSHClient() + client.load_system_host_keys() + client.set_missing_host_key_policy(paramiko.WarningPolicy()) + _logger.debug( + f'Connecting to ssh server with \n' + ' host = {ssh_host}\n' + ' port = {ssh_port}\n' + ' user = {ssh_user}\n' + ' password = {ssh_password}\n' + ' key_filename = {ssh_key_filename}\n' + ) + client.connect( + ssh_host, ssh_port, ssh_user, password=ssh_password, key_filename=ssh_key_filename + ) + return client + + +def read_config_file(config_path) -> paramiko.SSHConfig: + ssh_config = paramiko.config.SSHConfig() + try: + with open(config_path) as f: + ssh_config.parse(f) + except FileNotFoundError as e: + raise SSHException(str(e)) + # Paramiko prior to version 2.7 raises Exception on parse errors. + # In 2.7 it has become paramiko.ssh_exception.SSHException, + # but let's catch everything for compatibility + except Exception as err: + raise SSHException( + f"Could not parse SSH configuration file {config_path}:\n{err} ", + ) + return ssh_config diff --git a/mycli/sqlexecute.py b/mycli/sqlexecute.py index 94614387..36592cac 100644 --- a/mycli/sqlexecute.py +++ b/mycli/sqlexecute.py @@ -8,10 +8,7 @@ from pymysql.converters import (convert_datetime, convert_timedelta, convert_date, conversions, decoders) -try: - import paramiko -except ImportError: - from mycli.packages.paramiko_stub import paramiko + _logger = logging.getLogger(__name__) @@ -97,8 +94,8 @@ class SQLExecute(object): order by table_name,ordinal_position''' def __init__(self, database, user, password, host, port, socket, charset, - local_infile, ssl, ssh_user, ssh_host, ssh_port, ssh_password, - ssh_key_filename, init_command=None): + local_infile, ssl, init_command=None, + ssh_client=None): self.dbname = database self.user = user self.password = password @@ -110,18 +107,14 @@ def __init__(self, database, user, password, host, port, socket, charset, self.ssl = ssl self.server_info = None self.connection_id = None - self.ssh_user = ssh_user - self.ssh_host = ssh_host - self.ssh_port = ssh_port - self.ssh_password = ssh_password - self.ssh_key_filename = ssh_key_filename self.init_command = init_command + self.ssh_client = ssh_client + self.connect() def connect(self, database=None, user=None, password=None, host=None, - port=None, socket=None, charset=None, local_infile=None, - ssl=None, ssh_host=None, ssh_port=None, ssh_user=None, - ssh_password=None, ssh_key_filename=None, init_command=None): + port=None, socket=None, charset=None, local_infile=None, ssl=None, + ssh_client=None, init_command=None): db = (database or self.dbname) user = (user or self.user) password = (password or self.password) @@ -131,11 +124,6 @@ def connect(self, database=None, user=None, password=None, host=None, charset = (charset or self.charset) local_infile = (local_infile or self.local_infile) ssl = (ssl or self.ssl) - ssh_user = (ssh_user or self.ssh_user) - ssh_host = (ssh_host or self.ssh_host) - ssh_port = (ssh_port or self.ssh_port) - ssh_password = (ssh_password or self.ssh_password) - ssh_key_filename = (ssh_key_filename or self.ssh_key_filename) init_command = (init_command or self.init_command) _logger.debug( 'Connection DB Params: \n' @@ -147,15 +135,10 @@ def connect(self, database=None, user=None, password=None, host=None, '\tcharset: %r' '\tlocal_infile: %r' '\tssl: %r' - '\tssh_user: %r' - '\tssh_host: %r' - '\tssh_port: %r' - '\tssh_password: %r' - '\tssh_key_filename: %r' - '\tinit_command: %r', + '\tinit_command: %r' + '\tusing ssh: %r', db, user, host, port, socket, charset, local_infile, ssl, - ssh_user, ssh_host, ssh_port, ssh_password, ssh_key_filename, - init_command + init_command, bool(ssh_client) ) conv = conversions.copy() conv.update({ @@ -167,9 +150,6 @@ def connect(self, database=None, user=None, password=None, host=None, defer_connect = False - if ssh_host: - defer_connect = True - client_flag = pymysql.constants.CLIENT.INTERACTIVE if init_command and len(list(special.split_queries(init_command))) > 1: client_flag |= pymysql.constants.CLIENT.MULTI_STATEMENTS @@ -179,18 +159,12 @@ def connect(self, database=None, user=None, password=None, host=None, unix_socket=socket, use_unicode=True, charset=charset, autocommit=True, client_flag=client_flag, local_infile=local_infile, conv=conv, ssl=ssl, program_name="mycli", - defer_connect=defer_connect, init_command=init_command + defer_connect=self.ssh_client is not None, + init_command=init_command ) - if ssh_host: - client = paramiko.SSHClient() - client.load_system_host_keys() - client.set_missing_host_key_policy(paramiko.WarningPolicy()) - client.connect( - ssh_host, ssh_port, ssh_user, ssh_password, - key_filename=ssh_key_filename - ) - chan = client.get_transport().open_channel( + if self.ssh_client: + chan = self.ssh_client.get_transport().open_channel( 'direct-tcpip', (host, port), ('0.0.0.0', 0), diff --git a/test/conftest.py b/test/conftest.py index d7d10ce3..1c3ae94a 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,4 +1,6 @@ import pytest + +from mycli.packages.ssh_client import create_ssh_client from .utils import (HOST, USER, PASSWORD, PORT, CHARSET, create_db, db_connection, SSH_USER, SSH_HOST, SSH_PORT) import mycli.sqlexecute @@ -21,9 +23,13 @@ def cursor(connection): @pytest.fixture def executor(connection): + if SSH_HOST: + ssh_client = create_ssh_client(SSH_HOST, SSH_PORT, SSH_USER) + else: + ssh_client = None + return mycli.sqlexecute.SQLExecute( database='_test_db', user=USER, host=HOST, password=PASSWORD, port=PORT, socket=None, charset=CHARSET, - local_infile=False, ssl=None, ssh_user=SSH_USER, ssh_host=SSH_HOST, - ssh_port=SSH_PORT, ssh_password=None, ssh_key_filename=None + local_infile=False, ssl=None, ssh_client=ssh_client ) diff --git a/test/test_main.py b/test/test_main.py index 00fdc1bd..07d8f25a 100644 --- a/test/test_main.py +++ b/test/test_main.py @@ -433,6 +433,7 @@ def __init__(self, **args): self.logger = Logger() self.destructive_warning = False self.formatter = Formatter() + self._ssh_client = None def connect(self, **args): MockMyCli.connect_args = args @@ -440,8 +441,24 @@ def connect(self, **args): def run_query(self, query, new_line=True): pass + @property + def ssh_client(self): + pass + + @ssh_client.setter + def ssh_client(self, client): + MockMyCli._ssh_client = client + + def mock_create_ssh_client(*args): + return namedtuple( + 'SSHConf', + ('host', 'port', 'user', 'password', 'key_filename') + )(*args) + import mycli.main monkeypatch.setattr(mycli.main, 'MyCli', MockMyCli) + monkeypatch.setattr(mycli.main, 'create_ssh_client', + mock_create_ssh_client) runner = CliRunner() # Setup temporary configuration @@ -465,10 +482,10 @@ def run_query(self, query, new_line=True): assert result.exit_code == 0, result.output + \ " " + str(result.exception) assert \ - MockMyCli.connect_args["ssh_user"] == "joe" and \ - MockMyCli.connect_args["ssh_host"] == "test.example.com" and \ - MockMyCli.connect_args["ssh_port"] == 22222 and \ - MockMyCli.connect_args["ssh_key_filename"] == os.getenv( + MockMyCli._ssh_client.user == "joe" and \ + MockMyCli._ssh_client.host == "test.example.com" and \ + MockMyCli._ssh_client.port == 22222 and \ + MockMyCli._ssh_client.key_filename == os.getenv( "HOME") + "/.ssh/gateway" # When a user supplies a ssh config host as argument to mycli, @@ -487,10 +504,10 @@ def run_query(self, query, new_line=True): assert result.exit_code == 0, result.output + \ " " + str(result.exception) assert \ - MockMyCli.connect_args["ssh_user"] == "arg_user" and \ - MockMyCli.connect_args["ssh_host"] == "arg_host" and \ - MockMyCli.connect_args["ssh_port"] == 3 and \ - MockMyCli.connect_args["ssh_key_filename"] == "/path/to/key" + MockMyCli._ssh_client.user == "arg_user" and \ + MockMyCli._ssh_client.host == "arg_host" and \ + MockMyCli._ssh_client.port == 3 and \ + MockMyCli._ssh_client.key_filename == "/path/to/key" @dbtest