From 1b07848ea2e713b7637daee1402a7b689734799b Mon Sep 17 00:00:00 2001 From: Saverio Guzzo Date: Thu, 9 Dec 2021 12:55:06 +0100 Subject: [PATCH 1/2] dummy solution with two optional parameters --- airflow/providers/sftp/hooks/sftp.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/airflow/providers/sftp/hooks/sftp.py b/airflow/providers/sftp/hooks/sftp.py index ec1f860bf2371..881407057ff82 100644 --- a/airflow/providers/sftp/hooks/sftp.py +++ b/airflow/providers/sftp/hooks/sftp.py @@ -16,6 +16,7 @@ # specific language governing permissions and limitations # under the License. """This module contains SFTP hook.""" +import warnings import datetime import stat from typing import Dict, List, Optional, Tuple @@ -32,7 +33,7 @@ class SFTPHook(SSHHook): This hook is inherited from SSH hook. Please refer to SSH hook for the input arguments. - Interact with SFTP. Aims to be interchangeable with FTPHook. + Interact with SFTP. :Pitfalls:: @@ -50,7 +51,7 @@ class SFTPHook(SSHHook): :type sftp_conn_id: str """ - conn_name_attr = 'ftp_conn_id' + conn_name_attr = 'ssh_conn_id' default_conn_name = 'sftp_default' conn_type = 'sftp' hook_name = 'SFTP' @@ -64,8 +65,22 @@ def get_ui_field_behaviour() -> Dict: }, } - def __init__(self, ftp_conn_id: str = 'sftp_default', *args, **kwargs) -> None: - kwargs['ssh_conn_id'] = ftp_conn_id + def __init__( + self, + ssh_conn_id: Optional[str] = 'sftp_default', + ftp_conn_id: Optional[str] = 'sftp_default', + *args, **kwargs + ) -> None: + + if ftp_conn_id: + warnings.warn( + 'Parameter `ftp_conn_id` is deprecated.' + 'Please use `ssh_conn_id` instead.', + DeprecationWarning, + stacklevel=2, + ) + kwargs['ssh_conn_id'] = ftp_conn_id + self.ssh_conn_id = ssh_conn_id super().__init__(*args, **kwargs) self.conn = None @@ -82,7 +97,6 @@ def __init__(self, ftp_conn_id: str = 'sftp_default', *args, **kwargs) -> None: # For backward compatibility # TODO: remove in Airflow 2.1 - import warnings if 'private_key_pass' in extra_options: warnings.warn( From 5c7f41369017495e3f316557d749d45860cf78c5 Mon Sep 17 00:00:00 2001 From: Saverio Guzzo Date: Tue, 14 Dec 2021 11:25:43 +0100 Subject: [PATCH 2/2] added tests for both versions of hooks it should be possible to create tests for the two versions of the hook constructor using 'parametrized' --- airflow/providers/sftp/hooks/sftp.py | 18 +- .../sftp/hooks/test_sftp_outdated.py | 340 ++++++++++++++++++ 2 files changed, 352 insertions(+), 6 deletions(-) create mode 100644 tests/providers/sftp/hooks/test_sftp_outdated.py diff --git a/airflow/providers/sftp/hooks/sftp.py b/airflow/providers/sftp/hooks/sftp.py index 881407057ff82..cc75b0a161ddf 100644 --- a/airflow/providers/sftp/hooks/sftp.py +++ b/airflow/providers/sftp/hooks/sftp.py @@ -16,9 +16,9 @@ # specific language governing permissions and limitations # under the License. """This module contains SFTP hook.""" -import warnings import datetime import stat +import warnings from typing import Dict, List, Optional, Tuple import pysftp @@ -47,8 +47,14 @@ class SFTPHook(SSHHook): Errors that may occur throughout but should be handled downstream. - :param sftp_conn_id: The :ref:`sftp connection id` - :type sftp_conn_id: str + For consistency reasons with SSHHook, the preferred parameter is "ssh_conn_id". + Please note that it is still possible to use the parameter "ftp_conn_id" + to initialize the hook, but it will be removed in future Airflow versions. + + :param ssh_conn_id: The :ref:`sftp connection id` + :type ssh_conn_id: str + :param ftp_conn_id (Outdated): The :ref:`sftp connection id` + :type ftp_conn_id: str """ conn_name_attr = 'ssh_conn_id' @@ -69,13 +75,13 @@ def __init__( self, ssh_conn_id: Optional[str] = 'sftp_default', ftp_conn_id: Optional[str] = 'sftp_default', - *args, **kwargs + *args, + **kwargs, ) -> None: if ftp_conn_id: warnings.warn( - 'Parameter `ftp_conn_id` is deprecated.' - 'Please use `ssh_conn_id` instead.', + 'Parameter `ftp_conn_id` is deprecated.' 'Please use `ssh_conn_id` instead.', DeprecationWarning, stacklevel=2, ) diff --git a/tests/providers/sftp/hooks/test_sftp_outdated.py b/tests/providers/sftp/hooks/test_sftp_outdated.py new file mode 100644 index 0000000000000..72ba9def200c8 --- /dev/null +++ b/tests/providers/sftp/hooks/test_sftp_outdated.py @@ -0,0 +1,340 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import json +import os +import shutil +import unittest +from io import StringIO +from unittest import mock + +import paramiko +import pysftp +from parameterized import parameterized + +from airflow.models import Connection +from airflow.providers.sftp.hooks.sftp import SFTPHook +from airflow.utils.session import provide_session + + +def generate_host_key(pkey: paramiko.PKey): + key_fh = StringIO() + pkey.write_private_key(key_fh) + key_fh.seek(0) + key_obj = paramiko.RSAKey(file_obj=key_fh) + return key_obj.get_base64() + + +TMP_PATH = '/tmp' +TMP_DIR_FOR_TESTS = 'tests_sftp_hook_dir' +SUB_DIR = "sub_dir" +TMP_FILE_FOR_TESTS = 'test_file.txt' + +SFTP_CONNECTION_USER = "root" + +TEST_PKEY = paramiko.RSAKey.generate(4096) +TEST_HOST_KEY = generate_host_key(pkey=TEST_PKEY) +TEST_KEY_FILE = "~/.ssh/id_rsa" + + +class TestSFTPHook(unittest.TestCase): + @provide_session + def update_connection(self, login, session=None): + connection = session.query(Connection).filter(Connection.conn_id == "sftp_default").first() + old_login = connection.login + connection.login = login + session.commit() + return old_login + + def setUp(self): + self.old_login = self.update_connection(SFTP_CONNECTION_USER) + self.hook = SFTPHook(ftp_conn_id='sftp_default') + os.makedirs(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR)) + + with open(os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS), 'a') as file: + file.write('Test file') + with open(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR, TMP_FILE_FOR_TESTS), 'a') as file: + file.write('Test file') + + def test_get_conn(self): + output = self.hook.get_conn() + assert isinstance(output, pysftp.Connection) + + def test_close_conn(self): + self.hook.conn = self.hook.get_conn() + assert self.hook.conn is not None + self.hook.close_conn() + assert self.hook.conn is None + + def test_describe_directory(self): + output = self.hook.describe_directory(TMP_PATH) + assert TMP_DIR_FOR_TESTS in output + + def test_list_directory(self): + output = self.hook.list_directory(path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS)) + assert output == [SUB_DIR] + + def test_create_and_delete_directory(self): + new_dir_name = 'new_dir' + self.hook.create_directory(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_name)) + output = self.hook.describe_directory(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS)) + assert new_dir_name in output + self.hook.delete_directory(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_name)) + output = self.hook.describe_directory(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS)) + assert new_dir_name not in output + + def test_create_and_delete_directories(self): + base_dir = "base_dir" + sub_dir = "sub_dir" + new_dir_path = os.path.join(base_dir, sub_dir) + self.hook.create_directory(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_path)) + output = self.hook.describe_directory(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS)) + assert base_dir in output + output = self.hook.describe_directory(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, base_dir)) + assert sub_dir in output + self.hook.delete_directory(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_path)) + self.hook.delete_directory(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, base_dir)) + output = self.hook.describe_directory(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS)) + assert new_dir_path not in output + assert base_dir not in output + + def test_store_retrieve_and_delete_file(self): + self.hook.store_file( + remote_full_path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS), + local_full_path=os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS), + ) + output = self.hook.list_directory(path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS)) + assert output == [SUB_DIR, TMP_FILE_FOR_TESTS] + retrieved_file_name = 'retrieved.txt' + self.hook.retrieve_file( + remote_full_path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS), + local_full_path=os.path.join(TMP_PATH, retrieved_file_name), + ) + assert retrieved_file_name in os.listdir(TMP_PATH) + os.remove(os.path.join(TMP_PATH, retrieved_file_name)) + self.hook.delete_file(path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS)) + output = self.hook.list_directory(path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS)) + assert output == [SUB_DIR] + + def test_get_mod_time(self): + self.hook.store_file( + remote_full_path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS), + local_full_path=os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS), + ) + output = self.hook.get_mod_time(path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS)) + assert len(output) == 14 + + @mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection') + def test_no_host_key_check_default(self, get_connection): + connection = Connection(login='login', host='host') + get_connection.return_value = connection + hook = SFTPHook() + assert hook.no_host_key_check is False + + @mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection') + def test_no_host_key_check_enabled(self, get_connection): + connection = Connection(login='login', host='host', extra='{"no_host_key_check": true}') + + get_connection.return_value = connection + hook = SFTPHook() + assert hook.no_host_key_check is True + + @mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection') + def test_no_host_key_check_disabled(self, get_connection): + connection = Connection(login='login', host='host', extra='{"no_host_key_check": false}') + + get_connection.return_value = connection + hook = SFTPHook() + assert hook.no_host_key_check is False + + @mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection') + def test_ciphers(self, get_connection): + connection = Connection(login='login', host='host', extra='{"ciphers": ["A", "B", "C"]}') + + get_connection.return_value = connection + hook = SFTPHook() + assert hook.ciphers == ["A", "B", "C"] + + @mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection') + def test_no_host_key_check_disabled_for_all_but_true(self, get_connection): + connection = Connection(login='login', host='host', extra='{"no_host_key_check": "foo"}') + + get_connection.return_value = connection + hook = SFTPHook() + assert hook.no_host_key_check is False + + @mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection') + def test_no_host_key_check_ignore(self, get_connection): + connection = Connection(login='login', host='host', extra='{"ignore_hostkey_verification": true}') + + get_connection.return_value = connection + hook = SFTPHook() + assert hook.no_host_key_check is True + + @mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection') + def test_no_host_key_check_no_ignore(self, get_connection): + connection = Connection(login='login', host='host', extra='{"ignore_hostkey_verification": false}') + + get_connection.return_value = connection + hook = SFTPHook() + assert hook.no_host_key_check is False + + @mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection') + def test_host_key_default(self, get_connection): + connection = Connection(login='login', host='host') + get_connection.return_value = connection + hook = SFTPHook() + assert hook.host_key is None + + @mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection') + def test_host_key(self, get_connection): + connection = Connection( + login='login', + host='host', + extra=json.dumps({"host_key": TEST_HOST_KEY}), + ) + get_connection.return_value = connection + hook = SFTPHook() + assert hook.host_key.get_base64() == TEST_HOST_KEY + + @mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection') + def test_host_key_with_type(self, get_connection): + connection = Connection( + login='login', + host='host', + extra=json.dumps({"host_key": "ssh-rsa " + TEST_HOST_KEY}), + ) + get_connection.return_value = connection + hook = SFTPHook() + assert hook.host_key.get_base64() == TEST_HOST_KEY + + @mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection') + def test_host_key_with_no_host_key_check(self, get_connection): + connection = Connection(login='login', host='host', extra=json.dumps({"host_key": TEST_HOST_KEY})) + get_connection.return_value = connection + hook = SFTPHook() + assert hook.host_key is not None + + @mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection') + def test_key_content_as_str(self, get_connection): + file_obj = StringIO() + TEST_PKEY.write_private_key(file_obj) + file_obj.seek(0) + key_content_str = file_obj.read() + + connection = Connection( + login='login', + host='host', + extra=json.dumps( + { + "private_key": key_content_str, + } + ), + ) + get_connection.return_value = connection + hook = SFTPHook() + assert hook.pkey == TEST_PKEY + assert hook.key_file is None + + @mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection') + def test_key_file(self, get_connection): + connection = Connection( + login='login', + host='host', + extra=json.dumps( + { + "key_file": TEST_KEY_FILE, + } + ), + ) + get_connection.return_value = connection + hook = SFTPHook() + assert hook.key_file == TEST_KEY_FILE + + @parameterized.expand( + [ + (os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS), True), + (os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS), True), + (os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS + "abc"), False), + (os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, "abc"), False), + ] + ) + def test_path_exists(self, path, exists): + result = self.hook.path_exists(path) + assert result == exists + + @parameterized.expand( + [ + ("test/path/file.bin", None, None, True), + ("test/path/file.bin", "test", None, True), + ("test/path/file.bin", "test/", None, True), + ("test/path/file.bin", None, "bin", True), + ("test/path/file.bin", "test", "bin", True), + ("test/path/file.bin", "test/", "file.bin", True), + ("test/path/file.bin", None, "file.bin", True), + ("test/path/file.bin", "diff", None, False), + ("test/path/file.bin", "test//", None, False), + ("test/path/file.bin", None, ".txt", False), + ("test/path/file.bin", "diff", ".txt", False), + ] + ) + def test_path_match(self, path, prefix, delimiter, match): + result = self.hook._is_path_match(path=path, prefix=prefix, delimiter=delimiter) + assert result == match + + def test_get_tree_map(self): + tree_map = self.hook.get_tree_map(path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS)) + files, dirs, unknowns = tree_map + + assert files == [os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR, TMP_FILE_FOR_TESTS)] + assert dirs == [os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR)] + assert unknowns == [] + + @mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection') + def test_connection_failure(self, mock_get_connection): + connection = Connection( + login='login', + host='host', + ) + mock_get_connection.return_value = connection + with mock.patch.object(SFTPHook, 'get_conn') as get_conn: + type(get_conn.return_value).pwd = mock.PropertyMock(side_effect=Exception('Connection Error')) + + hook = SFTPHook() + status, msg = hook.test_connection() + assert status is False + assert msg == 'Connection Error' + + @mock.patch('airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection') + def test_connection_success(self, mock_get_connection): + connection = Connection( + login='login', + host='host', + ) + mock_get_connection.return_value = connection + + with mock.patch.object(SFTPHook, 'get_conn') as get_conn: + get_conn.return_value.pwd = '/home/someuser' + hook = SFTPHook() + status, msg = hook.test_connection() + assert status is True + assert msg == 'Connection successfully tested' + + def tearDown(self): + shutil.rmtree(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS)) + os.remove(os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS)) + self.update_connection(self.old_login)