Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reuse ssh connection #869

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Features:
* Add `-g` shortcut to option `--login-path`.
* Alt-Enter dispatches the command in multi-line mode.
* Allow to pass a file or FIFO path with --password-file when password is not specified or is failing (as suggested in this best-practice https://www.netmeister.org/blog/passing-passwords.html)
* Reuse the same SSH connection for both main thread and completion thread.

Internal:
---------
Expand Down Expand Up @@ -846,6 +847,7 @@ Bug Fixes:
[Georgy Frolov]: https://github.com/pasenor
[Zach DeCook]: https://zachdecook.com
[laixintao]: https://github.com/laixintao
[Nathan Huang]: https://github.com/hxueh
[mtorromeo]: https://github.com/mtorromeo
[mwcm]: https://github.com/mwcm
[xeron]: https://github.com/xeron
3 changes: 1 addition & 2 deletions mycli/completion_refresher.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,7 @@ def _bg_refresh(self, sqlexecute, callbacks, completer_options):
e = sqlexecute
executor = SQLExecute(e.dbname, e.user, e.password, e.host, e.port,
e.socket, e.charset, e.local_infile, e.ssl,
e.ssh_user, e.ssh_host, e.ssh_port,
e.ssh_password, e.ssh_key_filename)
ssh_client=e.ssh_client)

# If callbacks is a single function then push it into a list.
if callable(callbacks):
Expand Down
62 changes: 34 additions & 28 deletions mycli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,11 @@
from prompt_toolkit.history import FileHistory
from prompt_toolkit.auto_suggest import AutoSuggestFromHistory

from mycli.packages.ssh_client import create_ssh_client
from .packages.special.main import NO_QUERY
from .packages.prompt_utils import confirm, confirm_destructive_query
from .packages.tabular_output import sql_format
from .packages import special
from .packages import special, ssh_client
from .packages.special.favoritequeries import FavoriteQueries
from .sqlcompleter import SQLCompleter
from .clitoolbar import create_toolbar_tokens_func
Expand Down Expand Up @@ -74,11 +75,6 @@
# Python < 3.7
import importlib_resources as resources

try:
import paramiko
except ImportError:
from mycli.packages.paramiko_stub import paramiko

# Query tuples are used for maintaining history
Query = namedtuple('Query', ['query', 'successful', 'mutating'])

Expand Down Expand Up @@ -211,6 +207,8 @@ def __init__(self, sqlexecute=None, prompt=None,

self.prompt_app = None

self.ssh_client = None

def register_special_commands(self):
special.register_special_command(self.change_db, 'use',
'\\u', 'Change to a new database.', aliases=('\\u',))
Expand Down Expand Up @@ -387,9 +385,8 @@ def merge_ssl_with_cnf(self, ssl, cnf):
return merged

def connect(self, database='', user='', passwd='', host='', port='',
socket='', charset='', local_infile='', ssl='',
ssh_user='', ssh_host='', ssh_port='',
ssh_password='', ssh_key_filename='', init_command='', password_file=''):
socket='', charset='', local_infile='', ssl=None, init_command='',
password_file=''):

cnf = {'database': None,
'user': None,
Expand Down Expand Up @@ -418,7 +415,7 @@ def connect(self, database='', user='', passwd='', host='', port='',
ssl = ssl or {}

port = port and int(port)
if not port:
if not port and not self.ssh_client:
port = 3306
if not host or host == 'localhost':
socket = (
Expand Down Expand Up @@ -455,8 +452,7 @@ def _connect():
try:
self.sqlexecute = SQLExecute(
database, user, passwd, host, port, socket, charset,
local_infile, ssl, ssh_user, ssh_host, ssh_port,
ssh_password, ssh_key_filename, init_command
local_infile, ssl, init_command, ssh_client=self.ssh_client
)
except OperationalError as e:
if e.args[0] == ERROR_CODE_ACCESS_DENIED:
Expand All @@ -467,8 +463,8 @@ def _connect():
show_default=False, type=str, err=True)
self.sqlexecute = SQLExecute(
database, user, new_passwd, host, port, socket,
charset, local_infile, ssl, ssh_user, ssh_host,
ssh_port, ssh_password, ssh_key_filename, init_command
charset, local_infile, ssl, init_command,
ssh_client=self.ssh_client
)
else:
raise e
Expand Down Expand Up @@ -1179,16 +1175,22 @@ def cli(database, user, host, port, socket, password, dbname,
else:
click.secho(alias)
sys.exit(0)

if list_ssh_config:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Side note: the list_ssh_config feature is unreliable if the SSH config file uses certain features such as Match. Listing the hosts is not really in the scope of a tool such as mycli, and we ought to consider removing it.

ssh_config = read_ssh_config(ssh_config_path)
for host in ssh_config.get_hostnames():
try:
hosts = ssh_client.get_config_hosts(ssh_config_path)
except ssh_client.SSHException as e:
click.secho(str(e), err=True, fg='red')
sys.exit(1)

for host, hostname in hosts.items():
if verbose:
host_config = ssh_config.lookup(host)
click.secho("{} : {}".format(
host, host_config.get('hostname')))
host, hostname))
else:
click.secho(host)
sys.exit(0)

# Choose which ever one has a valid value.
database = dbname or database

Expand Down Expand Up @@ -1240,9 +1242,14 @@ def cli(database, user, host, port, socket, password, dbname,
port = uri.port

if ssh_config_host:
ssh_config = read_ssh_config(
ssh_config_path
).lookup(ssh_config_host)
try:
ssh_config = ssh_client.read_config_file(
ssh_config_path
).lookup(ssh_config_host)
except ssh_client.SSHException as e:
click.secho(str(e), err=True, fg='red')
sys.exit(1)

ssh_host = ssh_host if ssh_host else ssh_config.get('hostname')
ssh_user = ssh_user if ssh_user else ssh_config.get('user')
if ssh_config.get('port') and ssh_port == 22:
Expand All @@ -1251,7 +1258,10 @@ def cli(database, user, host, port, socket, password, dbname,
ssh_key_filename = ssh_key_filename if ssh_key_filename else ssh_config.get(
'identityfile', [None])[0]

ssh_key_filename = ssh_key_filename and os.path.expanduser(ssh_key_filename)
if ssh_host:
mycli.ssh_client = create_ssh_client(
ssh_host, ssh_port, ssh_user, ssh_password, ssh_key_filename
)

mycli.connect(
database=database,
Expand All @@ -1262,14 +1272,9 @@ def cli(database, user, host, port, socket, password, dbname,
socket=socket,
local_infile=local_infile,
ssl=ssl,
ssh_user=ssh_user,
ssh_host=ssh_host,
ssh_port=ssh_port,
ssh_password=ssh_password,
ssh_key_filename=ssh_key_filename,
init_command=init_command,
charset=charset,
password_file=password_file
password_file=password_file,
)

mycli.logger.debug('Launch Params: \n'
Expand Down Expand Up @@ -1403,6 +1408,7 @@ def read_ssh_config(ssh_config_path):
except FileNotFoundError as e:
click.secho(str(e), err=True, fg='red')
sys.exit(1)

# Paramiko prior to version 2.7 raises Exception on parse errors.
# In 2.7 it has become paramiko.ssh_exception.SSHException,
# but let's catch everything for compatibility
Expand Down
1 change: 1 addition & 0 deletions mycli/packages/ssh_client/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .client import get_config_hosts, create_ssh_client, SSHException, read_config_file
57 changes: 57 additions & 0 deletions mycli/packages/ssh_client/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""A very thin wrapper around paramiko, mostly to keep all SSH-related
functionality in one place."""
from io import open
import logging

_logger = logging.getLogger(__name__)

try:
import paramiko
except ImportError:
from mycli.packages.paramiko_stub import paramiko


class SSHException(Exception):
pass


def get_config_hosts(config_path):
config = read_config_file(config_path)
return {
host: config.lookup(host).get("hostname") for host in config.get_hostnames()
}


def create_ssh_client(ssh_host, ssh_port, ssh_user, ssh_password=None, ssh_key_filename=None) -> paramiko.SSHClient:
client = paramiko.SSHClient()
client.load_system_host_keys()
client.set_missing_host_key_policy(paramiko.WarningPolicy())
_logger.debug(
f'Connecting to ssh server with \n'
' host = {ssh_host}\n'
' port = {ssh_port}\n'
' user = {ssh_user}\n'
' password = {ssh_password}\n'
' key_filename = {ssh_key_filename}\n'
)
client.connect(
ssh_host, ssh_port, ssh_user, password=ssh_password, key_filename=ssh_key_filename
)
return client


def read_config_file(config_path) -> paramiko.SSHConfig:
ssh_config = paramiko.config.SSHConfig()
try:
with open(config_path) as f:
ssh_config.parse(f)
except FileNotFoundError as e:
raise SSHException(str(e))
# Paramiko prior to version 2.7 raises Exception on parse errors.
# In 2.7 it has become paramiko.ssh_exception.SSHException,
# but let's catch everything for compatibility
except Exception as err:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

except FileNotFoundError as e: should come ahead of more general Exception.

raise SSHException(
f"Could not parse SSH configuration file {config_path}:\n{err} ",
)
return ssh_config
54 changes: 14 additions & 40 deletions mycli/sqlexecute.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,7 @@
from pymysql.converters import (convert_datetime,
convert_timedelta, convert_date, conversions,
decoders)
try:
import paramiko
except ImportError:
from mycli.packages.paramiko_stub import paramiko


_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -97,8 +94,8 @@ class SQLExecute(object):
order by table_name,ordinal_position'''

def __init__(self, database, user, password, host, port, socket, charset,
local_infile, ssl, ssh_user, ssh_host, ssh_port, ssh_password,
ssh_key_filename, init_command=None):
local_infile, ssl, init_command=None,
ssh_client=None):
self.dbname = database
self.user = user
self.password = password
Expand All @@ -110,18 +107,14 @@ def __init__(self, database, user, password, host, port, socket, charset,
self.ssl = ssl
self.server_info = None
self.connection_id = None
self.ssh_user = ssh_user
self.ssh_host = ssh_host
self.ssh_port = ssh_port
self.ssh_password = ssh_password
self.ssh_key_filename = ssh_key_filename
self.init_command = init_command
self.ssh_client = ssh_client

self.connect()

def connect(self, database=None, user=None, password=None, host=None,
port=None, socket=None, charset=None, local_infile=None,
ssl=None, ssh_host=None, ssh_port=None, ssh_user=None,
ssh_password=None, ssh_key_filename=None, init_command=None):
port=None, socket=None, charset=None, local_infile=None, ssl=None,
ssh_client=None, init_command=None):
db = (database or self.dbname)
user = (user or self.user)
password = (password or self.password)
Expand All @@ -131,11 +124,6 @@ def connect(self, database=None, user=None, password=None, host=None,
charset = (charset or self.charset)
local_infile = (local_infile or self.local_infile)
ssl = (ssl or self.ssl)
ssh_user = (ssh_user or self.ssh_user)
ssh_host = (ssh_host or self.ssh_host)
ssh_port = (ssh_port or self.ssh_port)
ssh_password = (ssh_password or self.ssh_password)
ssh_key_filename = (ssh_key_filename or self.ssh_key_filename)
init_command = (init_command or self.init_command)
_logger.debug(
'Connection DB Params: \n'
Expand All @@ -147,15 +135,10 @@ def connect(self, database=None, user=None, password=None, host=None,
'\tcharset: %r'
'\tlocal_infile: %r'
'\tssl: %r'
'\tssh_user: %r'
'\tssh_host: %r'
'\tssh_port: %r'
'\tssh_password: %r'
'\tssh_key_filename: %r'
'\tinit_command: %r',
'\tinit_command: %r'
'\tusing ssh: %r',
db, user, host, port, socket, charset, local_infile, ssl,
ssh_user, ssh_host, ssh_port, ssh_password, ssh_key_filename,
init_command
init_command, bool(ssh_client)
)
conv = conversions.copy()
conv.update({
Expand All @@ -167,9 +150,6 @@ def connect(self, database=None, user=None, password=None, host=None,

defer_connect = False

if ssh_host:
defer_connect = True

client_flag = pymysql.constants.CLIENT.INTERACTIVE
if init_command and len(list(special.split_queries(init_command))) > 1:
client_flag |= pymysql.constants.CLIENT.MULTI_STATEMENTS
Expand All @@ -179,18 +159,12 @@ def connect(self, database=None, user=None, password=None, host=None,
unix_socket=socket, use_unicode=True, charset=charset,
autocommit=True, client_flag=client_flag,
local_infile=local_infile, conv=conv, ssl=ssl, program_name="mycli",
defer_connect=defer_connect, init_command=init_command
defer_connect=self.ssh_client is not None,
init_command=init_command
)

if ssh_host:
client = paramiko.SSHClient()
client.load_system_host_keys()
client.set_missing_host_key_policy(paramiko.WarningPolicy())
client.connect(
ssh_host, ssh_port, ssh_user, ssh_password,
key_filename=ssh_key_filename
)
chan = client.get_transport().open_channel(
if self.ssh_client:
chan = self.ssh_client.get_transport().open_channel(
'direct-tcpip',
(host, port),
('0.0.0.0', 0),
Expand Down
10 changes: 8 additions & 2 deletions test/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import pytest

from mycli.packages.ssh_client import create_ssh_client
from .utils import (HOST, USER, PASSWORD, PORT, CHARSET, create_db,
db_connection, SSH_USER, SSH_HOST, SSH_PORT)
import mycli.sqlexecute
Expand All @@ -21,9 +23,13 @@ def cursor(connection):

@pytest.fixture
def executor(connection):
if SSH_HOST:
ssh_client = create_ssh_client(SSH_HOST, SSH_PORT, SSH_USER)
else:
ssh_client = None

return mycli.sqlexecute.SQLExecute(
database='_test_db', user=USER,
host=HOST, password=PASSWORD, port=PORT, socket=None, charset=CHARSET,
local_infile=False, ssl=None, ssh_user=SSH_USER, ssh_host=SSH_HOST,
ssh_port=SSH_PORT, ssh_password=None, ssh_key_filename=None
local_infile=False, ssl=None, ssh_client=ssh_client
)
Loading