Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reuse ssh connection #869

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Features:
---------

* Add an option `--init-command` to execute SQL after connecting (Thanks: [KITAGAWA Yasutaka]).
* Reuse the same SSH connection in both main thread and completion thread (Thanks: [Georgy Frolov]).

1.22.2
======
Expand Down Expand Up @@ -785,3 +786,4 @@ Bug Fixes:
[Georgy Frolov]: https://github.com/pasenor
[Zach DeCook]: https://zachdecook.com
[laixintao]: https://github.com/laixintao
[Nathan Huang]: https://github.com/hxueh
3 changes: 1 addition & 2 deletions mycli/completion_refresher.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,7 @@ def _bg_refresh(self, sqlexecute, callbacks, completer_options):
e = sqlexecute
executor = SQLExecute(e.dbname, e.user, e.password, e.host, e.port,
e.socket, e.charset, e.local_infile, e.ssl,
e.ssh_user, e.ssh_host, e.ssh_port,
e.ssh_password, e.ssh_key_filename)
ssh_client=e.ssh_client)

# If callbacks is a single function then push it into a list.
if callable(callbacks):
Expand Down
81 changes: 32 additions & 49 deletions mycli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,11 @@
from prompt_toolkit.history import FileHistory
from prompt_toolkit.auto_suggest import AutoSuggestFromHistory

from mycli.packages.ssh_client import create_ssh_client
from .packages.special.main import NO_QUERY
from .packages.prompt_utils import confirm, confirm_destructive_query
from .packages.tabular_output import sql_format
from .packages import special
from .packages import special, ssh_client
from .packages.special.favoritequeries import FavoriteQueries
from .sqlcompleter import SQLCompleter
from .clitoolbar import create_toolbar_tokens_func
Expand Down Expand Up @@ -66,11 +67,6 @@
from urllib.parse import unquote


try:
import paramiko
except ImportError:
from mycli.packages.paramiko_stub import paramiko

# Query tuples are used for maintaining history
Query = namedtuple('Query', ['query', 'successful', 'mutating'])

Expand Down Expand Up @@ -201,6 +197,8 @@ def __init__(self, sqlexecute=None, prompt=None,

self.prompt_app = None

self.ssh_client = None

def register_special_commands(self):
special.register_special_command(self.change_db, 'use',
'\\u', 'Change to a new database.', aliases=('\\u',))
Expand Down Expand Up @@ -361,9 +359,7 @@ def merge_ssl_with_cnf(self, ssl, cnf):
return merged

def connect(self, database='', user='', passwd='', host='', port='',
socket='', charset='', local_infile='', ssl='',
ssh_user='', ssh_host='', ssh_port='',
ssh_password='', ssh_key_filename='', init_command=''):
socket='', charset='', local_infile='', ssl=None, init_command=''):

cnf = {'database': None,
'user': None,
Expand All @@ -387,7 +383,7 @@ def connect(self, database='', user='', passwd='', host='', port='',

database = database or cnf['database']
# Socket interface not supported for SSH connections
if port or host or ssh_host or ssh_port:
if port or host or self.ssh_client:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On rebase this should become a little different, like

if port or (host and host != 'localhost') or self.ssh_client:
            socket = ''

socket = ''
else:
socket = socket or cnf['socket'] or guess_socket_location()
Expand Down Expand Up @@ -419,17 +415,16 @@ def _connect():
try:
self.sqlexecute = SQLExecute(
database, user, passwd, host, port, socket, charset,
local_infile, ssl, ssh_user, ssh_host, ssh_port,
ssh_password, ssh_key_filename, init_command
local_infile, ssl, init_command, ssh_client=self.ssh_client
)
except OperationalError as e:
if ('Access denied for user' in e.args[1]):
new_passwd = click.prompt('Password', hide_input=True,
show_default=False, type=str, err=True)
self.sqlexecute = SQLExecute(
database, user, new_passwd, host, port, socket,
charset, local_infile, ssl, ssh_user, ssh_host,
ssh_port, ssh_password, ssh_key_filename, init_command
charset, local_infile, ssl, init_command,
ssh_client=self.ssh_client
)
else:
raise e
Expand Down Expand Up @@ -1098,16 +1093,22 @@ def cli(database, user, host, port, socket, password, dbname,
else:
click.secho(alias)
sys.exit(0)

if list_ssh_config:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Side note: the list_ssh_config feature is unreliable if the SSH config file uses certain features such as Match. Listing the hosts is not really in the scope of a tool such as mycli, and we ought to consider removing it.

ssh_config = read_ssh_config(ssh_config_path)
for host in ssh_config.get_hostnames():
try:
hosts = ssh_client.get_config_hosts(ssh_config_path)
except ssh_client.SSHException as e:
click.secho(str(e), err=True, fg='red')
sys.exit(1)

for host, hostname in hosts.items():
if verbose:
host_config = ssh_config.lookup(host)
click.secho("{} : {}".format(
host, host_config.get('hostname')))
host, hostname))
else:
click.secho(host)
sys.exit(0)

# Choose which ever one has a valid value.
database = dbname or database

Expand Down Expand Up @@ -1159,9 +1160,14 @@ def cli(database, user, host, port, socket, password, dbname,
port = uri.port

if ssh_config_host:
ssh_config = read_ssh_config(
ssh_config_path
).lookup(ssh_config_host)
try:
ssh_config = ssh_client.read_config_file(
ssh_config_path
).lookup(ssh_config_host)
except ssh_client.SSHException as e:
click.secho(str(e), err=True, fg='red')
sys.exit(1)

ssh_host = ssh_host if ssh_host else ssh_config.get('hostname')
ssh_user = ssh_user if ssh_user else ssh_config.get('user')
if ssh_config.get('port') and ssh_port == 22:
Expand All @@ -1170,7 +1176,10 @@ def cli(database, user, host, port, socket, password, dbname,
ssh_key_filename = ssh_key_filename if ssh_key_filename else ssh_config.get(
'identityfile', [None])[0]

ssh_key_filename = ssh_key_filename and os.path.expanduser(ssh_key_filename)
if ssh_host:
mycli.ssh_client = create_ssh_client(
ssh_host, ssh_port, ssh_user, ssh_password, ssh_key_filename
)

mycli.connect(
database=database,
Expand All @@ -1181,12 +1190,7 @@ def cli(database, user, host, port, socket, password, dbname,
socket=socket,
local_infile=local_infile,
ssl=ssl,
ssh_user=ssh_user,
ssh_host=ssh_host,
ssh_port=ssh_port,
ssh_password=ssh_password,
ssh_key_filename=ssh_key_filename,
init_command=init_command
init_command=init_command,
)

mycli.logger.debug('Launch Params: \n'
Expand Down Expand Up @@ -1305,26 +1309,5 @@ def edit_and_execute(event):
buff.open_in_editor(validate_and_handle=False)


def read_ssh_config(ssh_config_path):
ssh_config = paramiko.config.SSHConfig()
try:
with open(ssh_config_path) as f:
ssh_config.parse(f)
# Paramiko prior to version 2.7 raises Exception on parse errors.
# In 2.7 it has become paramiko.ssh_exception.SSHException,
# but let's catch everything for compatibility
except Exception as err:
click.secho(
f'Could not parse SSH configuration file {ssh_config_path}:\n{err} ',
err=True, fg='red'
)
sys.exit(1)
except FileNotFoundError as e:
click.secho(str(e), err=True, fg='red')
sys.exit(1)
else:
return ssh_config


if __name__ == "__main__":
cli()
1 change: 1 addition & 0 deletions mycli/packages/ssh_client/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .client import get_config_hosts, create_ssh_client, SSHException, read_config_file
46 changes: 46 additions & 0 deletions mycli/packages/ssh_client/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""A very thin wrapper around paramiko, mostly to keep all SSH-related
functionality in one place."""
from io import open

try:
import paramiko
except ImportError:
from mycli.packages.paramiko_stub import paramiko


class SSHException(Exception):
pass


def get_config_hosts(config_path):
config = read_config_file(config_path)
return {
host: config.lookup(host).get("hostname") for host in config.get_hostnames()
}


def create_ssh_client(ssh_host, ssh_port, ssh_user, ssh_password=None, ssh_key_filename=None) -> paramiko.SSHClient:
client = paramiko.SSHClient()
client.load_system_host_keys()
client.set_missing_host_key_policy(paramiko.WarningPolicy())
client.connect(
ssh_host, ssh_port, ssh_user, password=ssh_password, key_filename=ssh_key_filename
)
return client


def read_config_file(config_path) -> paramiko.SSHConfig:
ssh_config = paramiko.config.SSHConfig()
try:
with open(config_path) as f:
ssh_config.parse(f)
# Paramiko prior to version 2.7 raises Exception on parse errors.
# In 2.7 it has become paramiko.ssh_exception.SSHException,
# but let's catch everything for compatibility
except Exception as err:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

except FileNotFoundError as e: should come ahead of more general Exception.

raise SSHException(
f"Could not parse SSH configuration file {config_path}:\n{err} ",
)
except FileNotFoundError as e:
raise SSHException(str(e))
return ssh_config
52 changes: 13 additions & 39 deletions mycli/sqlexecute.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,7 @@
from pymysql.converters import (convert_datetime,
convert_timedelta, convert_date, conversions,
decoders)
try:
import paramiko
except ImportError:
from mycli.packages.paramiko_stub import paramiko


_logger = logging.getLogger(__name__)

Expand All @@ -18,6 +15,7 @@
FIELD_TYPE.NULL: type(None)
})


class SQLExecute(object):

databases_query = '''SHOW DATABASES'''
Expand All @@ -41,8 +39,8 @@ class SQLExecute(object):
order by table_name,ordinal_position'''

def __init__(self, database, user, password, host, port, socket, charset,
local_infile, ssl, ssh_user, ssh_host, ssh_port, ssh_password,
ssh_key_filename, init_command=None):
local_infile, ssl, init_command=None,
ssh_client=None):
self.dbname = database
self.user = user
self.password = password
Expand All @@ -54,18 +52,14 @@ def __init__(self, database, user, password, host, port, socket, charset,
self.ssl = ssl
self._server_type = None
self.connection_id = None
self.ssh_user = ssh_user
self.ssh_host = ssh_host
self.ssh_port = ssh_port
self.ssh_password = ssh_password
self.ssh_key_filename = ssh_key_filename
self.init_command = init_command
self.ssh_client = ssh_client

self.connect()

def connect(self, database=None, user=None, password=None, host=None,
port=None, socket=None, charset=None, local_infile=None,
ssl=None, ssh_host=None, ssh_port=None, ssh_user=None,
ssh_password=None, ssh_key_filename=None, init_command=None):
port=None, socket=None, charset=None, local_infile=None, ssl=None,
ssh_client=None, init_command=None):
db = (database or self.dbname)
user = (user or self.user)
password = (password or self.password)
Expand All @@ -75,11 +69,6 @@ def connect(self, database=None, user=None, password=None, host=None,
charset = (charset or self.charset)
local_infile = (local_infile or self.local_infile)
ssl = (ssl or self.ssl)
ssh_user = (ssh_user or self.ssh_user)
ssh_host = (ssh_host or self.ssh_host)
ssh_port = (ssh_port or self.ssh_port)
ssh_password = (ssh_password or self.ssh_password)
ssh_key_filename = (ssh_key_filename or self.ssh_key_filename)
init_command = (init_command or self.init_command)
_logger.debug(
'Connection DB Params: \n'
Expand All @@ -90,15 +79,9 @@ def connect(self, database=None, user=None, password=None, host=None,
'\tsocket: %r'
'\tcharset: %r'
'\tlocal_infile: %r'
'\tssl: %r'
'\tssh_user: %r'
'\tssh_host: %r'
'\tssh_port: %r'
'\tssh_password: %r'
'\tssh_key_filename: %r'
'\tssl: %r',
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could still log these values, if we wanted to, right?

'\tinit_command: %r',
db, user, host, port, socket, charset, local_infile, ssl,
ssh_user, ssh_host, ssh_port, ssh_password, ssh_key_filename,
init_command
)
conv = conversions.copy()
Expand All @@ -111,9 +94,6 @@ def connect(self, database=None, user=None, password=None, host=None,

defer_connect = False

if ssh_host:
defer_connect = True

client_flag = pymysql.constants.CLIENT.INTERACTIVE
if init_command and len(list(special.split_queries(init_command))) > 1:
client_flag |= pymysql.constants.CLIENT.MULTI_STATEMENTS
Expand All @@ -123,18 +103,12 @@ def connect(self, database=None, user=None, password=None, host=None,
unix_socket=socket, use_unicode=True, charset=charset,
autocommit=True, client_flag=client_flag,
local_infile=local_infile, conv=conv, ssl=ssl, program_name="mycli",
defer_connect=defer_connect, init_command=init_command
defer_connect=self.ssh_client is not None,
init_command=init_command
)

if ssh_host:
client = paramiko.SSHClient()
client.load_system_host_keys()
client.set_missing_host_key_policy(paramiko.WarningPolicy())
client.connect(
ssh_host, ssh_port, ssh_user, ssh_password,
key_filename=ssh_key_filename
)
chan = client.get_transport().open_channel(
if self.ssh_client:
chan = self.ssh_client.get_transport().open_channel(
'direct-tcpip',
(host, port),
('0.0.0.0', 0),
Expand Down
10 changes: 8 additions & 2 deletions test/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import pytest

from mycli.packages.ssh_client import create_ssh_client
from .utils import (HOST, USER, PASSWORD, PORT, CHARSET, create_db,
db_connection, SSH_USER, SSH_HOST, SSH_PORT)
import mycli.sqlexecute
Expand All @@ -21,9 +23,13 @@ def cursor(connection):

@pytest.fixture
def executor(connection):
if SSH_HOST:
ssh_client = create_ssh_client(SSH_HOST, SSH_PORT, SSH_USER)
else:
ssh_client = None

return mycli.sqlexecute.SQLExecute(
database='_test_db', user=USER,
host=HOST, password=PASSWORD, port=PORT, socket=None, charset=CHARSET,
local_infile=False, ssl=None, ssh_user=SSH_USER, ssh_host=SSH_HOST,
ssh_port=SSH_PORT, ssh_password=None, ssh_key_filename=None
local_infile=False, ssl=None, ssh_client=ssh_client
)
Loading