diff --git a/CHANGELOG.rst b/CHANGELOG.rst index cf94d9c938..5dfbc6ed12 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -10,6 +10,11 @@ Added to ensure update to the models require schema to be regenerated. (new feature) * Improved st2sensor service logging message when a sensor will not be loaded when assigned to a different partition (@punkrokk) +* Add support for a configurable connect timeout for SSH connections as requested in #4715 + by adding the new configuration parameter ``ssh_connect_timeout`` to the ``ssh_runner`` + group in st2.conf. (new feature) #4914 + + This option was requested by Harry Lee (@tclh123) and contributed by Marcel Weinberg (@winem). Fixed ~~~~~ @@ -52,7 +57,6 @@ Removed Added ~~~~~ - * Add support for blacklisting / whitelisting hosts to the HTTP runner by adding new ``url_hosts_blacklist`` and ``url_hosts_whitelist`` runner attribute. (new feature) #4757 diff --git a/conf/st2.conf.sample b/conf/st2.conf.sample index bd99093c24..32c3e60148 100644 --- a/conf/st2.conf.sample +++ b/conf/st2.conf.sample @@ -299,14 +299,16 @@ logging = /etc/st2/logging.sensorcontainer.conf sensor_node_name = sensornode1 [ssh_runner] +# Path to the ssh config file. +ssh_config_file_path = ~/.ssh/config # Max number of parallel remote SSH actions that should be run. Works only with Paramiko SSH runner. max_parallel_actions = 50 +# Max time in seconds to establish the SSH connection. +ssh_connect_timeout = 60 # Location of the script on the remote filesystem. remote_dir = /tmp # Use the .ssh/config file. Useful to override ports etc. use_ssh_config = False -# Path to the ssh config file. -ssh_config_file_path = ~/.ssh/config # How partial success of actions run on multiple nodes should be treated. allow_partial_failure = False diff --git a/st2actions/tests/unit/test_parallel_ssh.py b/st2actions/tests/unit/test_parallel_ssh.py index 89fb4e75bd..24faa849d6 100644 --- a/st2actions/tests/unit/test_parallel_ssh.py +++ b/st2actions/tests/unit/test_parallel_ssh.py @@ -152,7 +152,8 @@ def test_run_command_timeout(self): connect=True) mock_run = Mock(side_effect=SSHCommandTimeoutError(cmd='pwd', timeout=10, stdout='a', - stderr='b')) + stderr='b', + ssh_connect_timeout=30)) for host in hosts: hostname, _ = client._get_host_port_info(host) host_client = client._hosts_client[host] diff --git a/st2actions/tests/unit/test_paramiko_ssh.py b/st2actions/tests/unit/test_paramiko_ssh.py index e7e136c792..daac166344 100644 --- a/st2actions/tests/unit/test_paramiko_ssh.py +++ b/st2actions/tests/unit/test_paramiko_ssh.py @@ -43,12 +43,13 @@ def setUp(self): """ cfg.CONF.set_override(name='ssh_key_file', override=None, group='system_user') cfg.CONF.set_override(name='use_ssh_config', override=False, group='ssh_runner') + cfg.CONF.set_override(name='ssh_connect_timeout', override=30, group='ssh_runner') conn_params = {'hostname': 'dummy.host.org', 'port': 8822, 'username': 'ubuntu', 'key_files': '~/.ssh/ubuntu_ssh', - 'timeout': '600'} + 'timeout': 30} self.ssh_cli = ParamikoSSHClient(**conn_params) @patch('paramiko.SSHClient', Mock) @@ -108,7 +109,7 @@ def test_create_with_password(self): 'allow_agent': False, 'hostname': 'dummy.host.org', 'look_for_keys': False, - 'timeout': 60, + 'timeout': 30, 'port': 22} mock.client.connect.assert_called_once_with(**expected_conn) @@ -127,7 +128,7 @@ def test_deprecated_key_argument(self): 'hostname': 'dummy.host.org', 'look_for_keys': False, 'key_filename': 'id_rsa', - 'timeout': 60, + 'timeout': 30, 'port': 22} mock.client.connect.assert_called_once_with(**expected_conn) @@ -167,7 +168,7 @@ def test_key_material_argument(self): 'hostname': 'dummy.host.org', 'look_for_keys': False, 'pkey': pkey, - 'timeout': 60, + 'timeout': 30, 'port': 22} mock.client.connect.assert_called_once_with(**expected_conn) @@ -231,7 +232,7 @@ def test_key_with_passphrase_success(self): 'hostname': 'dummy.host.org', 'look_for_keys': False, 'pkey': pkey, - 'timeout': 60, + 'timeout': 30, 'port': 22} mock.client.connect.assert_called_once_with(**expected_conn) @@ -249,7 +250,7 @@ def test_key_with_passphrase_success(self): 'look_for_keys': False, 'key_filename': path, 'password': 'testphrase', - 'timeout': 60, + 'timeout': 30, 'port': 22} mock.client.connect.assert_called_once_with(**expected_conn) @@ -325,7 +326,7 @@ def test_create_with_key(self): 'hostname': 'dummy.host.org', 'look_for_keys': False, 'key_filename': 'id_rsa', - 'timeout': 60, + 'timeout': 30, 'port': 22} mock.client.connect.assert_called_once_with(**expected_conn) @@ -345,7 +346,7 @@ def test_create_with_key_via_bastion(self): 'hostname': 'bastion.host.org', 'look_for_keys': False, 'key_filename': 'id_rsa', - 'timeout': 60, + 'timeout': 30, 'port': 22} mock.bastion_client.connect.assert_called_once_with(**expected_bastion_conn) @@ -354,7 +355,7 @@ def test_create_with_key_via_bastion(self): 'hostname': 'dummy.host.org', 'look_for_keys': False, 'key_filename': 'id_rsa', - 'timeout': 60, + 'timeout': 30, 'port': 22, 'sock': mock.bastion_socket} mock.client.connect.assert_called_once_with(**expected_conn) @@ -376,7 +377,7 @@ def test_create_with_password_and_key(self): 'hostname': 'dummy.host.org', 'look_for_keys': False, 'key_filename': 'id_rsa', - 'timeout': 60, + 'timeout': 30, 'port': 22} mock.client.connect.assert_called_once_with(**expected_conn) @@ -417,7 +418,7 @@ def test_create_without_credentials_use_default_key(self): 'key_filename': 'stanley_rsa', 'allow_agent': False, 'look_for_keys': False, - 'timeout': 60, + 'timeout': 30, 'port': 22} mock.client.connect.assert_called_once_with(**expected_conn) @@ -446,7 +447,7 @@ def test_basic_usage_absolute_path(self): 'allow_agent': False, 'hostname': 'dummy.host.org', 'look_for_keys': False, - 'timeout': '600', + 'timeout': 28, 'port': 8822} mock_cli.connect.assert_called_once_with(**expected_conn) diff --git a/st2common/st2common/config.py b/st2common/st2common/config.py index e0cbb92278..18da4c594b 100644 --- a/st2common/st2common/config.py +++ b/st2common/st2common/config.py @@ -421,7 +421,10 @@ def register_opts(ignore_errors=False): help='Use the .ssh/config file. Useful to override ports etc.'), cfg.StrOpt( 'ssh_config_file_path', default='~/.ssh/config', - help='Path to the ssh config file.') + help='Path to the ssh config file.'), + cfg.IntOpt( + 'ssh_connect_timeout', default=60, + help='Max time in seconds to establish the SSH connection.') ] do_register_opts(ssh_runner_opts, group='ssh_runner') diff --git a/st2common/st2common/runners/paramiko_ssh.py b/st2common/st2common/runners/paramiko_ssh.py index b96098a89e..3523717424 100644 --- a/st2common/st2common/runners/paramiko_ssh.py +++ b/st2common/st2common/runners/paramiko_ssh.py @@ -49,7 +49,7 @@ class SSHCommandTimeoutError(Exception): Exception which is raised when an SSH command times out. """ - def __init__(self, cmd, timeout, stdout=None, stderr=None): + def __init__(self, cmd, timeout, ssh_connect_timeout, stdout=None, stderr=None): """ :param stdout: Stdout which was consumed until the timeout occured. :type stdout: ``str`` @@ -59,14 +59,16 @@ def __init__(self, cmd, timeout, stdout=None, stderr=None): """ self.cmd = cmd self.timeout = timeout + self.ssh_connect_timeout = ssh_connect_timeout self.stdout = stdout self.stderr = stderr - self.message = 'Command didn\'t finish in %s seconds' % (timeout) + self.message = ('Command didn\'t finish in %s seconds or the SSH connection ' + 'did not succeed in %s seconds' % (timeout, ssh_connect_timeout)) super(SSHCommandTimeoutError, self).__init__(self.message) def __repr__(self): - return ('' % - (self.cmd, self.timeout)) + return ('' % + (self.cmd, self.timeout, self.ssh_connect_timeout)) def __str__(self): return self.message @@ -83,9 +85,6 @@ class ParamikoSSHClient(object): # How long to sleep while waiting for command to finish to prevent busy waiting SLEEP_DELAY = 0.2 - # Connect socket timeout - CONNECT_TIMEOUT = 60 - def __init__(self, hostname, port=DEFAULT_SSH_PORT, username=None, password=None, bastion_host=None, key_files=None, key_material=None, timeout=None, passphrase=None, handle_stdout_line_func=None, handle_stderr_line_func=None): @@ -105,10 +104,11 @@ def __init__(self, hostname, port=DEFAULT_SSH_PORT, username=None, password=None self.username = username self.password = password self.key_files = key_files - self.timeout = timeout or ParamikoSSHClient.CONNECT_TIMEOUT + self.timeout = timeout self.key_material = key_material self.bastion_host = bastion_host self.passphrase = passphrase + self.ssh_connect_timeout = cfg.CONF.ssh_runner.ssh_connect_timeout self._handle_stdout_line_func = handle_stdout_line_func self._handle_stderr_line_func = handle_stderr_line_func @@ -116,6 +116,11 @@ def __init__(self, hostname, port=DEFAULT_SSH_PORT, username=None, password=None cfg.CONF.ssh_runner.ssh_config_file_path or '~/.ssh/config' ) + + if self.timeout and int(self.ssh_connect_timeout) > int(self.timeout) - 2: + # the connect timeout should not be greater than the action timeout + self.ssh_connect_timeout = int(self.timeout) - 2 + self.logger = logging.getLogger(__name__) self.client = None @@ -415,8 +420,9 @@ def run(self, cmd, timeout=None, quote=False, call_line_handler_func=False): stdout = sanitize_output(stdout.getvalue(), uses_pty=uses_pty) stderr = sanitize_output(stderr.getvalue(), uses_pty=uses_pty) - raise SSHCommandTimeoutError(cmd=cmd, timeout=timeout, stdout=stdout, - stderr=stderr) + raise SSHCommandTimeoutError(cmd=cmd, timeout=timeout, + ssh_connect_timeout=self.ssh_connect_timeout, + stdout=stdout, stderr=stderr) stdout_data = self._consume_stdout(chan=chan, call_line_handler_func=call_line_handler_func) @@ -632,7 +638,7 @@ def _connect(self, host, socket=None): conninfo = {'hostname': host, 'allow_agent': False, 'look_for_keys': False, - 'timeout': self.timeout} + 'timeout': self.ssh_connect_timeout} ssh_config_file_info = {} if cfg.CONF.ssh_runner.use_ssh_config: @@ -701,7 +707,7 @@ def _connect(self, host, socket=None): conninfo['look_for_keys'] = True extra = {'_hostname': host, '_port': self.port, - '_username': self.username, '_timeout': self.timeout} + '_username': self.username, '_timeout': self.ssh_connect_timeout} self.logger.debug('Connecting to server', extra=extra) self.socket = socket or ssh_config_file_info.get('sock', None)