diff --git a/airflow/contrib/operators/sftp_operator.py b/airflow/contrib/operators/sftp_operator.py index 3c736c8b95101..a3b5c1f24492b 100644 --- a/airflow/contrib/operators/sftp_operator.py +++ b/airflow/contrib/operators/sftp_operator.py @@ -33,11 +33,15 @@ class SFTPOperator(BaseOperator): This operator uses ssh_hook to open sftp trasport channel that serve as basis for file transfer. - :param ssh_hook: predefined ssh_hook to use for remote execution + :param ssh_hook: predefined ssh_hook to use for remote execution. + Either `ssh_hook` or `ssh_conn_id` needs to be provided. :type ssh_hook: :class:`SSHHook` - :param ssh_conn_id: connection id from airflow Connections + :param ssh_conn_id: connection id from airflow Connections. + `ssh_conn_id` will be ignored if `ssh_hook` is provided. :type ssh_conn_id: str :param remote_host: remote host to connect (templated) + Nullable. If provided, it will replace the `remote_host` which was + defined in `ssh_hook` or predefined in the connection of `ssh_conn_id`. :type remote_host: str :param local_filepath: local file path to get or put. (templated) :type local_filepath: str @@ -77,13 +81,21 @@ def __init__(self, def execute(self, context): file_msg = None try: - if self.ssh_conn_id and not self.ssh_hook: - self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id) + if self.ssh_conn_id: + if self.ssh_hook and isinstance(self.ssh_hook, SSHHook): + self.log.info("ssh_conn_id is ignored when ssh_hook is provided.") + else: + self.log.info("ssh_hook is not provided or invalid. " + + "Trying ssh_conn_id to create SSHHook.") + self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id) if not self.ssh_hook: - raise AirflowException("can not operate without ssh_hook or ssh_conn_id") + raise AirflowException("Cannot operate without ssh_hook or ssh_conn_id.") if self.remote_host is not None: + self.log.info("remote_host is provided explicitly. 
" + + "It will replace the remote_host which was defined " + + "in ssh_hook or predefined in connection of ssh_conn_id.") self.ssh_hook.remote_host = self.remote_host with self.ssh_hook.get_conn() as ssh_client: diff --git a/airflow/contrib/operators/ssh_operator.py b/airflow/contrib/operators/ssh_operator.py index c0e8953d2c344..2bf342935d60c 100644 --- a/airflow/contrib/operators/ssh_operator.py +++ b/airflow/contrib/operators/ssh_operator.py @@ -31,11 +31,15 @@ class SSHOperator(BaseOperator): """ SSHOperator to execute commands on given remote host using the ssh_hook. - :param ssh_hook: predefined ssh_hook to use for remote execution + :param ssh_hook: predefined ssh_hook to use for remote execution. + Either `ssh_hook` or `ssh_conn_id` needs to be provided. :type ssh_hook: :class:`SSHHook` - :param ssh_conn_id: connection id from airflow Connections + :param ssh_conn_id: connection id from airflow Connections. + `ssh_conn_id` will be ignored if `ssh_hook` is provided. :type ssh_conn_id: str :param remote_host: remote host to connect (templated) + Nullable. If provided, it will replace the `remote_host` which was + defined in `ssh_hook` or predefined in the connection of `ssh_conn_id`. :type remote_host: str :param command: command to execute on remote host. (templated) :type command: str @@ -68,14 +72,22 @@ def __init__(self, def execute(self, context): try: - if self.ssh_conn_id and not self.ssh_hook: - self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id, - timeout=self.timeout) + if self.ssh_conn_id: + if self.ssh_hook and isinstance(self.ssh_hook, SSHHook): + self.log.info("ssh_conn_id is ignored when ssh_hook is provided.") + else: + self.log.info("ssh_hook is not provided or invalid. 
" + + "Trying ssh_conn_id to create SSHHook.") + self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id, + timeout=self.timeout) if not self.ssh_hook: raise AirflowException("Cannot operate without ssh_hook or ssh_conn_id.") if self.remote_host is not None: + self.log.info("remote_host is provided explicitly. " + + "It will replace the remote_host which was defined " + + "in ssh_hook or predefined in connection of ssh_conn_id.") self.ssh_hook.remote_host = self.remote_host if not self.command: diff --git a/tests/contrib/operators/test_sftp_operator.py b/tests/contrib/operators/test_sftp_operator.py index 01446a6fddd49..5770c1b940eb5 100644 --- a/tests/contrib/operators/test_sftp_operator.py +++ b/tests/contrib/operators/test_sftp_operator.py @@ -20,6 +20,7 @@ import os import unittest from base64 import b64encode +import six from airflow import configuration from airflow import models @@ -219,6 +220,71 @@ def test_json_file_transfer_get(self): self.assertEqual(content_received.strip(), test_remote_file_content.encode('utf-8').decode('utf-8')) + def test_arg_checking(self): + from airflow.exceptions import AirflowException + conn_id = "conn_id_for_testing" + os.environ['AIRFLOW_CONN_' + conn_id.upper()] = "ssh://test_id@localhost" + + # Exception should be raised if neither ssh_hook nor ssh_conn_id is provided + if six.PY2: + self.assertRaisesRegex = self.assertRaisesRegexp + with self.assertRaisesRegex(AirflowException, + "Cannot operate without ssh_hook or ssh_conn_id."): + task_0 = SFTPOperator( + task_id="test_sftp", + local_filepath=self.test_local_filepath, + remote_filepath=self.test_remote_filepath, + operation=SFTPOperation.PUT, + dag=self.dag + ) + task_0.execute(None) + + # if ssh_hook is invalid/not provided, use ssh_conn_id to create SSHHook + task_1 = SFTPOperator( + task_id="test_sftp", + ssh_hook="string_rather_than_SSHHook", # invalid ssh_hook + ssh_conn_id=conn_id, + local_filepath=self.test_local_filepath, + 
remote_filepath=self.test_remote_filepath, + operation=SFTPOperation.PUT, + dag=self.dag + ) + try: + task_1.execute(None) + except Exception: + pass + self.assertEqual(task_1.ssh_hook.ssh_conn_id, conn_id) + + task_2 = SFTPOperator( + task_id="test_sftp", + ssh_conn_id=conn_id, # no ssh_hook provided + local_filepath=self.test_local_filepath, + remote_filepath=self.test_remote_filepath, + operation=SFTPOperation.PUT, + dag=self.dag + ) + try: + task_2.execute(None) + except Exception: + pass + self.assertEqual(task_2.ssh_hook.ssh_conn_id, conn_id) + + # if both valid ssh_hook and ssh_conn_id are provided, ignore ssh_conn_id + task_3 = SFTPOperator( + task_id="test_sftp", + ssh_hook=self.hook, + ssh_conn_id=conn_id, + local_filepath=self.test_local_filepath, + remote_filepath=self.test_remote_filepath, + operation=SFTPOperation.PUT, + dag=self.dag + ) + try: + task_3.execute(None) + except Exception: + pass + self.assertEqual(task_3.ssh_hook.ssh_conn_id, self.hook.ssh_conn_id) + def delete_local_resource(self): if os.path.exists(self.test_local_filepath): os.remove(self.test_local_filepath) @@ -226,11 +292,11 @@ def delete_local_resource(self): def delete_remote_resource(self): # check the remote file content remove_file_task = SSHOperator( - task_id="test_check_file", - ssh_hook=self.hook, - command="rm {0}".format(self.test_remote_filepath), - do_xcom_push=True, - dag=self.dag + task_id="test_check_file", + ssh_hook=self.hook, + command="rm {0}".format(self.test_remote_filepath), + do_xcom_push=True, + dag=self.dag ) self.assertIsNotNone(remove_file_task) ti3 = TaskInstance(task=remove_file_task, execution_date=timezone.utcnow()) diff --git a/tests/contrib/operators/test_ssh_operator.py b/tests/contrib/operators/test_ssh_operator.py index 7ddd24b2ac2ca..1a2c788596671 100644 --- a/tests/contrib/operators/test_ssh_operator.py +++ b/tests/contrib/operators/test_ssh_operator.py @@ -19,6 +19,7 @@ import unittest from base64 import b64encode +import six from airflow 
import configuration from airflow import models @@ -148,6 +149,65 @@ def test_no_output_command(self): self.assertIsNotNone(ti.duration) self.assertEqual(ti.xcom_pull(task_ids='test', key='return_value'), b'') + def test_arg_checking(self): + import os + from airflow.exceptions import AirflowException + conn_id = "conn_id_for_testing" + TIMEOUT = 5 + os.environ['AIRFLOW_CONN_' + conn_id.upper()] = "ssh://test_id@localhost" + + # Exception should be raised if neither ssh_hook nor ssh_conn_id is provided + if six.PY2: + self.assertRaisesRegex = self.assertRaisesRegexp + with self.assertRaisesRegex(AirflowException, + "Cannot operate without ssh_hook or ssh_conn_id."): + task_0 = SSHOperator(task_id="test", command="echo -n airflow", + timeout=TIMEOUT, dag=self.dag) + task_0.execute(None) + + # if ssh_hook is invalid/not provided, use ssh_conn_id to create SSHHook + task_1 = SSHOperator( + task_id="test_1", + ssh_hook="string_rather_than_SSHHook", # invalid ssh_hook + ssh_conn_id=conn_id, + command="echo -n airflow", + timeout=TIMEOUT, + dag=self.dag + ) + try: + task_1.execute(None) + except Exception: + pass + self.assertEqual(task_1.ssh_hook.ssh_conn_id, conn_id) + + task_2 = SSHOperator( + task_id="test_2", + ssh_conn_id=conn_id, # no ssh_hook provided + command="echo -n airflow", + timeout=TIMEOUT, + dag=self.dag + ) + try: + task_2.execute(None) + except Exception: + pass + self.assertEqual(task_2.ssh_hook.ssh_conn_id, conn_id) + + # if both valid ssh_hook and ssh_conn_id are provided, ignore ssh_conn_id + task_3 = SSHOperator( + task_id="test_3", + ssh_hook=self.hook, + ssh_conn_id=conn_id, + command="echo -n airflow", + timeout=TIMEOUT, + dag=self.dag + ) + try: + task_3.execute(None) + except Exception: + pass + self.assertEqual(task_3.ssh_hook.ssh_conn_id, self.hook.ssh_conn_id) + if __name__ == '__main__': unittest.main()