Skip to content

Commit

Permalink
Add support for SSH connection via aliases from ~/.ssh/config (#790)
Browse files Browse the repository at this point in the history
* Added support for SSH/SCP/SFTP configuration from the default OpenSSH config file.

* Added support for SSH/SCP/SFTP configuration from the default OpenSSH config file.

* Added tests for override of user and port from config

* Linting fixes

* simplify and clean up tests

* clean up ssh.py submodule

* fix linter issues

---------

Co-authored-by: Michael Penkov <m@penkov.dev>
  • Loading branch information
wbeardall and mpenkov committed Feb 20, 2024
1 parent bcc2335 commit 269c3a2
Show file tree
Hide file tree
Showing 4 changed files with 209 additions and 11 deletions.
144 changes: 136 additions & 8 deletions smart_open/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,16 @@
"""

import getpass
import os
import logging
import urllib.parse

from typing import (
Dict,
Callable,
Tuple,
)

try:
import paramiko
except ImportError:
Expand All @@ -52,11 +59,44 @@
'sftp://username@host/path/file',
)

#
# Global storage for SSH config files.
#
_SSH_CONFIG_FILES = [os.path.expanduser("~/.ssh/config")]


def _unquote(text):
return text and urllib.parse.unquote(text)


def _str2bool(string):
if string == "no":
return False
if string == "yes":
return True
raise ValueError(f"Expected 'yes' / 'no', got {string}.")


#
# The parameter names used by Paramiko (and smart_open) slightly differ to
# those used in ~/.ssh/config, so we use a mapping to bridge the gap.
#
# The keys are option names as they appear in Paramiko (and smart_open)
# The values are a tuples containing:
#
# 1. their corresponding names in the ~/.ssh/config file
# 2. a callable to convert the parameter value from a string to the appropriate type
#
_PARAMIKO_CONFIG_MAP: Dict[str, Tuple[str, Callable]] = {
"timeout": ("connecttimeout", float),
"compress": ("compression", _str2bool),
"gss_auth": ("gssapiauthentication", _str2bool),
"gss_kex": ("gssapikeyexchange", _str2bool),
"gss_deleg_creds": ("gssapidelegatecredentials", _str2bool),
"gss_trust_dns": ("gssapitrustdns", _str2bool),
}


def parse_uri(uri_as_string):
split_uri = urllib.parse.urlsplit(uri_as_string)
assert split_uri.scheme in SCHEMES
Expand All @@ -65,7 +105,7 @@ def parse_uri(uri_as_string):
uri_path=_unquote(split_uri.path),
user=_unquote(split_uri.username),
host=split_uri.hostname,
port=int(split_uri.port or DEFAULT_PORT),
port=int(split_uri.port) if split_uri.port else None,
password=_unquote(split_uri.password),
)

Expand All @@ -90,7 +130,98 @@ def _connect_ssh(hostname, username, port, password, transport_params):
return ssh


def open(path, mode='r', host=None, user=None, password=None, port=DEFAULT_PORT, transport_params=None):
def _maybe_fetch_config(host, username=None, password=None, port=None, transport_params=None):
# If all fields are set, return as-is.
if not any(arg is None for arg in (host, username, password, port, transport_params)):
return host, username, password, port, transport_params

if not host:
raise ValueError('you must specify the host to connect to')
if not transport_params:
transport_params = {}
if "connect_kwargs" not in transport_params:
transport_params["connect_kwargs"] = {}

# Attempt to load an OpenSSH config.
#
# Connections configured in this way are not guaranteed to perform exactly
# as they do in typical usage due to mismatches between the set of OpenSSH
# configuration options and those that Paramiko supports. We provide a best
# attempt, and support:
#
# - hostname -> address resolution
# - username inference
# - port inference
# - identityfile inference
# - connection timeout inference
# - compression selection
# - GSS configuration
#
connect_params = transport_params["connect_kwargs"]
config_files = [f for f in _SSH_CONFIG_FILES if os.path.exists(f)]
#
# This is the actual name of the host. The input host may actually be an
# alias.
#
actual_hostname = ""

for config_filename in config_files:
try:
cfg = paramiko.SSHConfig.from_path(config_filename)
except PermissionError:
continue

if host not in cfg.get_hostnames():
continue

cfg = cfg.lookup(host)
if username is None:
username = cfg.get("user", None)

if not actual_hostname:
actual_hostname = cfg["hostname"]

if port is None:
try:
port = int(cfg["port"])
except (IndexError, ValueError):
#
# Nb. ignore missing/invalid port numbers
#
pass

#
# Special case, as we can have multiple identity files, so we check
# that the identityfile list has len > 0. This should be redundant, but
# keeping it for safety.
#
if connect_params.get("key_filename") is None:
identityfile = cfg.get("identityfile", [])
if len(identityfile):
connect_params["key_filename"] = identityfile

for param_name, (sshcfg_name, from_str) in _PARAMIKO_CONFIG_MAP.items():
if connect_params.get(param_name) is None and sshcfg_name in cfg:
connect_params[param_name] = from_str(cfg[sshcfg_name])

#
# Continue working through other config files, if there are any,
# as they may contain more options for our host
#

if port is None:
port = DEFAULT_PORT

if not username:
username = getpass.getuser()

if actual_hostname:
host = actual_hostname

return host, username, password, port, transport_params


def open(path, mode='r', host=None, user=None, password=None, port=None, transport_params=None):
"""Open a file on a remote machine over SSH.
Expects authentication to be already set up via existing keys on the local machine.
Expand Down Expand Up @@ -125,12 +256,9 @@ def open(path, mode='r', host=None, user=None, password=None, port=DEFAULT_PORT,
If ``username`` or ``password`` are specified in *both* the uri and
``transport_params``, ``transport_params`` will take precedence
"""
if not host:
raise ValueError('you must specify the host to connect to')
if not user:
user = getpass.getuser()
if not transport_params:
transport_params = {}
host, user, password, port, transport_params = _maybe_fetch_config(
host, user, password, port, transport_params
)

key = (host, user)

Expand Down
11 changes: 11 additions & 0 deletions smart_open/tests/test_data/ssh.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
Host another-host
HostName another-host-domain.com
User another-user
Port 2345
IdentityFile /path/to/key/file
ConnectTimeout 20
Compression yes
GSSAPIAuthentication no
GSSAPIKeyExchange no
GSSAPIDelegateCredentials no
GSSAPITrustDns no
6 changes: 3 additions & 3 deletions smart_open/tests/test_smart_open.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ def test_scp(self):
self.assertEqual(uri.uri_path, '/path/to/file')
self.assertEqual(uri.user, 'user')
self.assertEqual(uri.host, 'host')
self.assertEqual(uri.port, 22)
self.assertEqual(uri.port, None)
self.assertEqual(uri.password, None)

def test_scp_with_pass(self):
Expand All @@ -325,7 +325,7 @@ def test_scp_with_pass(self):
self.assertEqual(uri.uri_path, '/path/to/file')
self.assertEqual(uri.user, 'user')
self.assertEqual(uri.host, 'host')
self.assertEqual(uri.port, 22)
self.assertEqual(uri.port, None)
self.assertEqual(uri.password, 'pass')

def test_sftp(self):
Expand All @@ -335,7 +335,7 @@ def test_sftp(self):
self.assertEqual(uri.uri_path, '/path/to/file')
self.assertEqual(uri.user, None)
self.assertEqual(uri.host, 'host')
self.assertEqual(uri.port, 22)
self.assertEqual(uri.port, None)
self.assertEqual(uri.password, None)

def test_sftp_with_user_and_pass(self):
Expand Down
59 changes: 59 additions & 0 deletions smart_open/tests/test_ssh.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
# -*- coding: utf-8 -*-

import logging
import os
import unittest
from unittest import mock

from paramiko import SSHException

import smart_open.ssh

_TEST_DATA_PATH = os.path.join(os.path.dirname(__file__), "test_data")
_CONFIG_PATH = os.path.join(_TEST_DATA_PATH, "ssh.cfg")


def mock_ssh(func):
def wrapper(*args, **kwargs):
Expand All @@ -20,6 +24,13 @@ def wrapper(*args, **kwargs):


class SSHOpen(unittest.TestCase):
def setUp(self):
self._cfg_files = smart_open.ssh._SSH_CONFIG_FILES
smart_open.ssh._SSH_CONFIG_FILES = [_CONFIG_PATH]

def tearDown(self):
smart_open.ssh._SSH_CONFIG_FILES = self._cfg_files

@mock_ssh
def test_open(self, mock_connect, get_transp_mock):
smart_open.open("ssh://user:pass@some-host/")
Expand Down Expand Up @@ -68,6 +79,54 @@ def mocked_open_sftp():
mock_connect.assert_called_with("some-host", 22, username="user", password="pass")
mock_sftp.open.assert_called_once()

@mock_ssh
def test_open_with_openssh_config(self, mock_connect, get_transp_mock):
smart_open.open("ssh://another-host/")
mock_connect.assert_called_with(
"another-host-domain.com",
2345,
username="another-user",
key_filename=["/path/to/key/file"],
timeout=20.,
compress=True,
gss_auth=False,
gss_kex=False,
gss_deleg_creds=False,
gss_trust_dns=False,
)

@mock_ssh
def test_open_with_openssh_config_override_port(self, mock_connect, get_transp_mock):
smart_open.open("ssh://another-host:22/")
mock_connect.assert_called_with(
"another-host-domain.com",
22,
username="another-user",
key_filename=["/path/to/key/file"],
timeout=20.,
compress=True,
gss_auth=False,
gss_kex=False,
gss_deleg_creds=False,
gss_trust_dns=False,
)

@mock_ssh
def test_open_with_openssh_config_override_user(self, mock_connect, get_transp_mock):
smart_open.open("ssh://new-user@another-host/")
mock_connect.assert_called_with(
"another-host-domain.com",
2345,
username="new-user",
key_filename=["/path/to/key/file"],
timeout=20.,
compress=True,
gss_auth=False,
gss_kex=False,
gss_deleg_creds=False,
gss_trust_dns=False,
)


if __name__ == "__main__":
logging.basicConfig(format="%(asctime)s : %(levelname)s : %(message)s", level=logging.DEBUG)
Expand Down

0 comments on commit 269c3a2

Please sign in to comment.