diff --git a/airflow/contrib/hooks/ssh_hook.py b/airflow/contrib/hooks/ssh_hook.py index f51f0fbd11948..0bc06c56a4e17 100755 --- a/airflow/contrib/hooks/ssh_hook.py +++ b/airflow/contrib/hooks/ssh_hook.py @@ -22,11 +22,12 @@ import getpass import os +import warnings import paramiko from paramiko.config import SSH_PORT +from sshtunnel import SSHTunnelForwarder -from contextlib import contextmanager from airflow.exceptions import AirflowException from airflow.hooks.base_hook import BaseHook from airflow.utils.log.logging_mixin import LoggingMixin @@ -65,7 +66,7 @@ def __init__(self, username=None, password=None, key_file=None, - port=SSH_PORT, + port=None, timeout=10, keepalive_interval=30 ): @@ -75,162 +76,167 @@ def __init__(self, self.username = username self.password = password self.key_file = key_file + self.port = port self.timeout = timeout self.keepalive_interval = keepalive_interval + # Default values, overridable from Connection self.compress = True self.no_host_key_check = True + self.host_proxy = None + + # Placeholder for deprecated __enter__ self.client = None - self.port = port + + # Use connection to override defaults + if self.ssh_conn_id is not None: + conn = self.get_connection(self.ssh_conn_id) + if self.username is None: + self.username = conn.login + if self.password is None: + self.password = conn.password + if self.remote_host is None: + self.remote_host = conn.host + if self.port is None: + self.port = conn.port + if conn.extra is not None: + extra_options = conn.extra_dejson + self.key_file = extra_options.get("key_file") + + if "timeout" in extra_options: + self.timeout = int(extra_options["timeout"], 10) + + if "compress" in extra_options\ + and str(extra_options["compress"]).lower() == 'false': + self.compress = False + if "no_host_key_check" in extra_options\ + and\ + str(extra_options["no_host_key_check"]).lower() == 'false': + self.no_host_key_check = False + + if not self.remote_host: + raise AirflowException("Missing required param: remote_host") + + # Auto detecting username values from system + if not self.username: + self.log.debug( + "username to ssh to host: %s is not specified for connection id" + " %s. Using system's default provided by getpass.getuser()", + self.remote_host, self.ssh_conn_id + ) + self.username = getpass.getuser() + + user_ssh_config_filename = os.path.expanduser('~/.ssh/config') + if os.path.isfile(user_ssh_config_filename): + ssh_conf = paramiko.SSHConfig() + ssh_conf.parse(open(user_ssh_config_filename)) + host_info = ssh_conf.lookup(self.remote_host) + if host_info and host_info.get('proxycommand'): + self.host_proxy = paramiko.ProxyCommand(host_info.get('proxycommand')) + + if not (self.password or self.key_file): + if host_info and host_info.get('identityfile'): + self.key_file = host_info.get('identityfile')[0] + + self.port = self.port or SSH_PORT def get_conn(self): - if not self.client: - self.log.debug('Creating SSH client for conn_id: %s', self.ssh_conn_id) - if self.ssh_conn_id is not None: - conn = self.get_connection(self.ssh_conn_id) - if self.username is None: - self.username = conn.login - if self.password is None: - self.password = conn.password - if self.remote_host is None: - self.remote_host = conn.host - if conn.port is not None: - self.port = conn.port - if conn.extra is not None: - extra_options = conn.extra_dejson - self.key_file = extra_options.get("key_file") - - if "timeout" in extra_options: - self.timeout = int(extra_options["timeout"], 10) - - if "compress" in extra_options \ - and str(extra_options["compress"]).lower() == 'false': - self.compress = False - if "no_host_key_check" in extra_options \ - and \ - str(extra_options["no_host_key_check"]).lower() == 'false': - self.no_host_key_check = False - - if not self.remote_host: - raise AirflowException("Missing required param: remote_host") - - # Auto detecting username values from system - if not self.username: - self.log.debug( - "username to ssh to host: %s is not specified for connection id" - " %s. Using system's default provided by getpass.getuser()", - self.remote_host, self.ssh_conn_id - ) - self.username = getpass.getuser() - - host_proxy = None - user_ssh_config_filename = os.path.expanduser('~/.ssh/config') - if os.path.isfile(user_ssh_config_filename): - ssh_conf = paramiko.SSHConfig() - ssh_conf.parse(open(user_ssh_config_filename)) - host_info = ssh_conf.lookup(self.remote_host) - if host_info and host_info.get('proxycommand'): - host_proxy = paramiko.ProxyCommand(host_info.get('proxycommand')) - - if not (self.password or self.key_file): - if host_info and host_info.get('identityfile'): - self.key_file = host_info.get('identityfile')[0] - - try: - client = paramiko.SSHClient() - client.load_system_host_keys() - if self.no_host_key_check: - # Default is RejectPolicy - client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - - if self.password and self.password.strip(): - client.connect(hostname=self.remote_host, - username=self.username, - password=self.password, - timeout=self.timeout, - compress=self.compress, - port=self.port, - sock=host_proxy) - else: - client.connect(hostname=self.remote_host, - username=self.username, - key_filename=self.key_file, - timeout=self.timeout, - compress=self.compress, - port=self.port, - sock=host_proxy) - - if self.keepalive_interval: - client.get_transport().set_keepalive(self.keepalive_interval) - - self.client = client - except paramiko.AuthenticationException as auth_error: - self.log.error( - "Auth failed while connecting to host: %s, error: %s", - self.remote_host, auth_error - ) - except paramiko.SSHException as ssh_error: - self.log.error( - "Failed connecting to host: %s, error: %s", - self.remote_host, ssh_error - ) - except Exception as error: - self.log.error( - "Error connecting to host: %s, error: %s", - self.remote_host, error - ) - return self.client - - @contextmanager - def create_tunnel(self, local_port, remote_port=None, remote_host="localhost"): """ - Creates a tunnel between two hosts. Like ssh -L :host:. - Remember to close() the returned "tunnel" object in order to clean up - after yourself when you are done with the tunnel. + Opens a ssh connection to the remote host. - :param local_port: - :type local_port: int - :param remote_port: - :type remote_port: int - :param remote_host: - :type remote_host: str - :return: + :return paramiko.SSHClient object """ - import subprocess - # this will ensure the connection to the ssh.remote_host from where the tunnel - # is getting created - self.get_conn() - - tunnel_host = "{0}:{1}:{2}".format(local_port, remote_host, remote_port) - - ssh_cmd = ["ssh", "{0}@{1}".format(self.username, self.remote_host), - "-o", "ControlMaster=no", - "-o", "UserKnownHostsFile=/dev/null", - "-o", "StrictHostKeyChecking=no"] - - ssh_tunnel_cmd = ["-L", tunnel_host, - "echo -n ready && cat" - ] - - ssh_cmd += ssh_tunnel_cmd - self.log.debug("Creating tunnel with cmd: %s", ssh_cmd) - - proc = subprocess.Popen(ssh_cmd, - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - close_fds=True) - ready = proc.stdout.read(5) - assert ready == b"ready", \ - "Did not get 'ready' from remote, got '{0}' instead".format(ready) - yield - proc.communicate() - assert proc.returncode == 0, \ - "Tunnel process did unclean exit (returncode {}".format(proc.returncode) + self.log.debug('Creating SSH client for conn_id: %s', self.ssh_conn_id) + client = paramiko.SSHClient() + client.load_system_host_keys() + if self.no_host_key_check: + # Default is RejectPolicy + client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + + if self.password and self.password.strip(): + client.connect(hostname=self.remote_host, + username=self.username, + password=self.password, + key_filename=self.key_file, + timeout=self.timeout, + compress=self.compress, + port=self.port, + sock=self.host_proxy) + else: + client.connect(hostname=self.remote_host, + username=self.username, + key_filename=self.key_file, + timeout=self.timeout, + compress=self.compress, + port=self.port, + sock=self.host_proxy) + + if self.keepalive_interval: + client.get_transport().set_keepalive(self.keepalive_interval) + + self.client = client + return client def __enter__(self): + warnings.warn('The contextmanager of SSHHook is deprecated.' + 'Please use get_conn() as a contextmanager instead.' + 'This method will be removed in Airflow 2.0', + category=DeprecationWarning) return self def __exit__(self, exc_type, exc_val, exc_tb): if self.client is not None: self.client.close() + self.client = None + + def get_tunnel(self, remote_port, remote_host="localhost", local_port=None): + """ + Creates a tunnel between two hosts. Like ssh -L :host:. + + :param remote_port: The remote port to create a tunnel to + :type remote_port: int + :param remote_host: The remote host to create a tunnel to (default localhost) + :type remote_host: str + :param local_port: The local port to attach the tunnel to + :type local_port: int + + :return: sshtunnel.SSHTunnelForwarder object + """ + + if local_port: + local_bind_address = ('localhost', local_port) + else: + local_bind_address = ('localhost',) + + if self.password and self.password.strip(): + client = SSHTunnelForwarder(self.remote_host, + ssh_port=self.port, + ssh_username=self.username, + ssh_password=self.password, + ssh_pkey=self.key_file, + ssh_proxy=self.host_proxy, + local_bind_address=local_bind_address, + remote_bind_address=(remote_host, remote_port), + logger=self.log) + else: + client = SSHTunnelForwarder(self.remote_host, + ssh_port=self.port, + ssh_username=self.username, + ssh_pkey=self.key_file, + ssh_proxy=self.host_proxy, + local_bind_address=local_bind_address, + remote_bind_address=(remote_host, remote_port), + host_pkey_directories=[], + logger=self.log) + + return client + + def create_tunnel(self, local_port, remote_port=None, remote_host="localhost"): + warnings.warn('SSHHook.create_tunnel is deprecated, Please' + 'use get_tunnel() instead. But please note that the' + 'order of the parameters have changed' + 'This method will be removed in Airflow 2.0', + category=DeprecationWarning) + + return self.get_tunnel(remote_port, remote_host, local_port) diff --git a/airflow/contrib/operators/sftp_operator.py b/airflow/contrib/operators/sftp_operator.py index cf6c6d254ecd1..def54c3fbadc9 100644 --- a/airflow/contrib/operators/sftp_operator.py +++ b/airflow/contrib/operators/sftp_operator.py @@ -86,20 +86,20 @@ def execute(self, context): if self.remote_host is not None: self.ssh_hook.remote_host = self.remote_host - ssh_client = self.ssh_hook.get_conn() - sftp_client = ssh_client.open_sftp() - if self.operation.lower() == SFTPOperation.GET: - file_msg = "from {0} to {1}".format(self.remote_filepath, - self.local_filepath) - self.log.debug("Starting to transfer %s", file_msg) - sftp_client.get(self.remote_filepath, self.local_filepath) - else: - file_msg = "from {0} to {1}".format(self.local_filepath, - self.remote_filepath) - self.log.debug("Starting to transfer file %s", file_msg) - sftp_client.put(self.local_filepath, - self.remote_filepath, - confirm=self.confirm) + with self.ssh_hook.get_conn() as ssh_client: + sftp_client = ssh_client.open_sftp() + if self.operation.lower() == SFTPOperation.GET: + file_msg = "from {0} to {1}".format(self.remote_filepath, + self.local_filepath) + self.log.debug("Starting to transfer %s", file_msg) + sftp_client.get(self.remote_filepath, self.local_filepath) + else: + file_msg = "from {0} to {1}".format(self.local_filepath, + self.remote_filepath) + self.log.debug("Starting to transfer file %s", file_msg) + sftp_client.put(self.local_filepath, + self.remote_filepath, + confirm=self.confirm) except Exception as e: raise AirflowException("Error while transferring {0}, error: {1}" diff --git a/airflow/contrib/operators/ssh_operator.py b/airflow/contrib/operators/ssh_operator.py index d246800953341..2e890f463ea88 100644 --- a/airflow/contrib/operators/ssh_operator.py +++ b/airflow/contrib/operators/ssh_operator.py @@ -77,79 +77,78 @@ def execute(self, context): if self.remote_host is not None: self.ssh_hook.remote_host = self.remote_host - ssh_client = self.ssh_hook.get_conn() - if not self.command: raise AirflowException("no command specified so nothing to execute here.") - # Auto apply tty when its required in case of sudo - get_pty = False - if self.command.startswith('sudo'): - get_pty = True - - # set timeout taken as params - stdin, stdout, stderr = ssh_client.exec_command(command=self.command, - get_pty=get_pty, - timeout=self.timeout - ) - # get channels - channel = stdout.channel - - # closing stdin - stdin.close() - channel.shutdown_write() - - agg_stdout = b'' - agg_stderr = b'' - - # capture any initial output in case channel is closed already - stdout_buffer_length = len(stdout.channel.in_buffer) - - if stdout_buffer_length > 0: - agg_stdout += stdout.channel.recv(stdout_buffer_length) - - # read from both stdout and stderr - while not channel.closed or \ - channel.recv_ready() or \ - channel.recv_stderr_ready(): - readq, _, _ = select([channel], [], [], self.timeout) - for c in readq: - if c.recv_ready(): - line = stdout.channel.recv(len(c.in_buffer)) - line = line - agg_stdout += line - self.log.info(line.decode('utf-8').strip('\n')) - if c.recv_stderr_ready(): - line = stderr.channel.recv_stderr(len(c.in_stderr_buffer)) - line = line - agg_stderr += line - self.log.warning(line.decode('utf-8').strip('\n')) - if stdout.channel.exit_status_ready()\ - and not stderr.channel.recv_stderr_ready()\ - and not stdout.channel.recv_ready(): - stdout.channel.shutdown_read() - stdout.channel.close() - break - - stdout.close() - stderr.close() - - exit_status = stdout.channel.recv_exit_status() - if exit_status is 0: - # returning output if do_xcom_push is set - if self.do_xcom_push: - enable_pickling = configuration.conf.getboolean( - 'core', 'enable_xcom_pickling' - ) - if enable_pickling: - return agg_stdout - else: - return b64encode(agg_stdout).decode('utf-8') - - else: - error_msg = agg_stderr.decode('utf-8') - raise AirflowException("error running cmd: {0}, error: {1}" - .format(self.command, error_msg)) + with self.ssh_hook.get_conn() as ssh_client: + # Auto apply tty when its required in case of sudo + get_pty = False + if self.command.startswith('sudo'): + get_pty = True + + # set timeout taken as params + stdin, stdout, stderr = ssh_client.exec_command(command=self.command, + get_pty=get_pty, + timeout=self.timeout + ) + # get channels + channel = stdout.channel + + # closing stdin + stdin.close() + channel.shutdown_write() + + agg_stdout = b'' + agg_stderr = b'' + + # capture any initial output in case channel is closed already + stdout_buffer_length = len(stdout.channel.in_buffer) + + if stdout_buffer_length > 0: + agg_stdout += stdout.channel.recv(stdout_buffer_length) + + # read from both stdout and stderr + while not channel.closed or \ + channel.recv_ready() or \ + channel.recv_stderr_ready(): + readq, _, _ = select([channel], [], [], self.timeout) + for c in readq: + if c.recv_ready(): + line = stdout.channel.recv(len(c.in_buffer)) + line = line + agg_stdout += line + self.log.info(line.decode('utf-8').strip('\n')) + if c.recv_stderr_ready(): + line = stderr.channel.recv_stderr(len(c.in_stderr_buffer)) + line = line + agg_stderr += line + self.log.warning(line.decode('utf-8').strip('\n')) + if stdout.channel.exit_status_ready()\ + and not stderr.channel.recv_stderr_ready()\ + and not stdout.channel.recv_ready(): + stdout.channel.shutdown_read() + stdout.channel.close() + break + + stdout.close() + stderr.close() + + exit_status = stdout.channel.recv_exit_status() + if exit_status is 0: + # returning output if do_xcom_push is set + if self.do_xcom_push: + enable_pickling = configuration.conf.getboolean( + 'core', 'enable_xcom_pickling' + ) + if enable_pickling: + return agg_stdout + else: + return b64encode(agg_stdout).decode('utf-8') + + else: + error_msg = agg_stderr.decode('utf-8') + raise AirflowException("error running cmd: {0}, error: {1}" + .format(self.command, error_msg)) except Exception as e: raise AirflowException("SSH operator error: {0}".format(str(e))) diff --git a/setup.py b/setup.py index b46a5a757b6b1..50af30944e414 100644 --- a/setup.py +++ b/setup.py @@ -208,7 +208,7 @@ def write_version(filename=os.path.join(*['airflow', mongo = ['pymongo>=3.6.0'] snowflake = ['snowflake-connector-python>=1.5.2', 'snowflake-sqlalchemy>=1.1.0'] -ssh = ['paramiko>=2.1.1', 'pysftp>=0.2.9'] +ssh = ['paramiko>=2.1.1', 'pysftp>=0.2.9', 'sshtunnel>=0.1.4,<0.2'] statsd = ['statsd>=3.0.1, <4.0'] vertica = ['vertica-python>=0.5.1'] webhdfs = ['hdfs[dataframe,avro,kerberos]>=2.0.4'] diff --git a/tests/contrib/hooks/test_ssh_hook.py b/tests/contrib/hooks/test_ssh_hook.py index b185a565c2f69..ad5621fe92357 100644 --- a/tests/contrib/hooks/test_ssh_hook.py +++ b/tests/contrib/hooks/test_ssh_hook.py @@ -7,9 +7,9 @@ # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -22,6 +22,17 @@ from airflow.utils import db from airflow import models +try: + from unittest import mock +except ImportError: + try: + import mock + except ImportError: + mock = None + +from airflow.contrib.hooks.ssh_hook import SSHHook + + HELLO_SERVER_CMD = """ import socket, sys listener = socket.socket() @@ -36,43 +47,90 @@ class SSHHookTest(unittest.TestCase): + def setUp(self): configuration.load_test_config() - from airflow.contrib.hooks.ssh_hook import SSHHook - self.hook = SSHHook(ssh_conn_id='ssh_default', keepalive_interval=10) - self.hook.no_host_key_check = True - def test_ssh_connection(self): - ssh_hook = self.hook.get_conn() - self.assertIsNotNone(ssh_hook) + @mock.patch('airflow.contrib.hooks.ssh_hook.paramiko.SSHClient') + def test_ssh_connection_with_password(self, ssh_mock): + hook = SSHHook(remote_host='remote_host', + port='port', + username='username', + password='password', + timeout=10, + key_file='fake.file') - def test_tunnel(self): - print("Setting up remote listener") - import subprocess - import socket + with hook.get_conn(): + ssh_mock.return_value.connect.assert_called_once_with( + hostname='remote_host', + username='username', + password='password', + key_filename='fake.file', + timeout=10, + compress=True, + port='port', + sock=None + ) - self.server_handle = subprocess.Popen(["python", "-c", HELLO_SERVER_CMD], - stdout=subprocess.PIPE) - print("Setting up tunnel") - with self.hook.create_tunnel(2135, 2134): - print("Tunnel up") - server_output = self.server_handle.stdout.read(5) - self.assertEqual(server_output, b"ready") - print("Connecting to server via tunnel") - s = socket.socket() - s.connect(("localhost", 2135)) - print("Receiving...",) - response = s.recv(5) - self.assertEqual(response, b"hello") - print("Closing connection") - s.close() - print("Waiting for listener...") - output, _ = self.server_handle.communicate() - self.assertEqual(self.server_handle.returncode, 0) - print("Closing tunnel") + @mock.patch('airflow.contrib.hooks.ssh_hook.paramiko.SSHClient') + def test_ssh_connection_without_password(self, ssh_mock): + hook = SSHHook(remote_host='remote_host', + port='port', + username='username', + timeout=10, + key_file='fake.file') + + with hook.get_conn(): + ssh_mock.return_value.connect.assert_called_once_with( + hostname='remote_host', + username='username', + key_filename='fake.file', + timeout=10, + compress=True, + port='port', + sock=None + ) + + @mock.patch('airflow.contrib.hooks.ssh_hook.SSHTunnelForwarder') + def test_tunnel_with_password(self, ssh_mock): + hook = SSHHook(remote_host='remote_host', + port='port', + username='username', + password='password', + timeout=10, + key_file='fake.file') + + with hook.get_tunnel(1234): + ssh_mock.assert_called_once_with('remote_host', + ssh_port='port', + ssh_username='username', + ssh_password='password', + ssh_pkey='fake.file', + ssh_proxy=None, + local_bind_address=('localhost', ), + remote_bind_address=('localhost', 1234), + logger=hook.log) + + @mock.patch('airflow.contrib.hooks.ssh_hook.SSHTunnelForwarder') + def test_tunnel_without_password(self, ssh_mock): + hook = SSHHook(remote_host='remote_host', + port='port', + username='username', + timeout=10, + key_file='fake.file') + + with hook.get_tunnel(1234): + ssh_mock.assert_called_once_with('remote_host', + ssh_port='port', + ssh_username='username', + ssh_pkey='fake.file', + ssh_proxy=None, + local_bind_address=('localhost', ), + remote_bind_address=('localhost', 1234), + host_pkey_directories=[], + logger=hook.log) def test_conn_with_extra_parameters(self): - from airflow.contrib.hooks.ssh_hook import SSHHook db.merge_conn( models.Connection(conn_id='ssh_with_extra', host='localhost', @@ -80,11 +138,41 @@ def test_conn_with_extra_parameters(self): extra='{"compress" : true, "no_host_key_check" : "true"}' ) ) - ssh_hook = SSHHook(ssh_conn_id='ssh_with_extra', keepalive_interval=10) - ssh_hook.get_conn() + ssh_hook = SSHHook(ssh_conn_id='ssh_with_extra') self.assertEqual(ssh_hook.compress, True) self.assertEqual(ssh_hook.no_host_key_check, True) + def test_ssh_connection(self): + hook = SSHHook(ssh_conn_id='ssh_default') + with hook.get_conn() as client: + (_, stdout, _) = client.exec_command('ls') + self.assertIsNotNone(stdout.read()) + + def test_ssh_connection_old_cm(self): + with SSHHook(ssh_conn_id='ssh_default') as hook: + client = hook.get_conn() + (_, stdout, _) = client.exec_command('ls') + self.assertIsNotNone(stdout.read()) + + def test_tunnel(self): + hook = SSHHook(ssh_conn_id='ssh_default') + + import subprocess + import socket + + server_handle = subprocess.Popen(["python", "-c", HELLO_SERVER_CMD], + stdout=subprocess.PIPE) + with hook.create_tunnel(2135, 2134): + server_output = server_handle.stdout.read(5) + self.assertEqual(server_output, b"ready") + s = socket.socket() + s.connect(("localhost", 2135)) + response = s.recv(5) + self.assertEqual(response, b"hello") + s.close() + output, _ = server_handle.communicate() + self.assertEqual(server_handle.returncode, 0) + if __name__ == '__main__': unittest.main()