diff --git a/airflow/contrib/operators/s3_to_sftp_operator.py b/airflow/contrib/operators/s3_to_sftp_operator.py
new file mode 100644
index 0000000000000..7e02e97cd24e4
--- /dev/null
+++ b/airflow/contrib/operators/s3_to_sftp_operator.py
@@ -0,0 +1,84 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from airflow.models import BaseOperator
+from airflow.hooks.S3_hook import S3Hook
+from airflow.contrib.hooks.ssh_hook import SSHHook
+from tempfile import NamedTemporaryFile
+from urllib.parse import urlparse
+from airflow.utils.decorators import apply_defaults
+
+
+class S3ToSFTPOperator(BaseOperator):
+    """
+    This operator enables the transferring of files from S3 to an SFTP server.
+    :param sftp_conn_id: The sftp connection id. The name or
+        identifier for establishing a connection to the SFTP server.
+    :type sftp_conn_id: string
+    :param sftp_path: The sftp remote path. This is the specified
+        file path for uploading the file to the SFTP server.
+    :type sftp_path: string
+    :param s3_conn_id: The s3 connection id. The name or identifier for establishing
+        a connection to S3.
+    :type s3_conn_id: string
+    :param s3_bucket: The targeted s3 bucket. This is the S3 bucket
+        from where the file is downloaded.
+    :type s3_bucket: string
+    :param s3_key: The targeted s3 key. This is the specified file path
+        for downloading the file from S3.
+    :type s3_key: string
+    """
+
+    template_fields = ('s3_key', 'sftp_path')
+
+    @apply_defaults
+    def __init__(self,
+                 s3_bucket,
+                 s3_key,
+                 sftp_path,
+                 sftp_conn_id='ssh_default',
+                 s3_conn_id='aws_default',
+                 *args,
+                 **kwargs):
+        super(S3ToSFTPOperator, self).__init__(*args, **kwargs)
+        self.sftp_conn_id = sftp_conn_id
+        self.sftp_path = sftp_path
+        self.s3_bucket = s3_bucket
+        self.s3_key = s3_key
+        self.s3_conn_id = s3_conn_id
+
+    @staticmethod
+    def get_s3_key(s3_key):
+        """This parses the correct format for S3 keys
+        regardless of how the S3 URL is passed."""
+
+        parsed_s3_key = urlparse(s3_key)
+        return parsed_s3_key.path.lstrip('/')
+
+    def execute(self, context):
+        self.s3_key = self.get_s3_key(self.s3_key)
+        ssh_hook = SSHHook(ssh_conn_id=self.sftp_conn_id)
+        s3_hook = S3Hook(self.s3_conn_id)
+
+        s3_client = s3_hook.get_conn()
+        sftp_client = ssh_hook.get_conn().open_sftp()
+
+        with NamedTemporaryFile("w") as f:
+            s3_client.download_file(self.s3_bucket, self.s3_key, f.name)
+            sftp_client.put(f.name, self.sftp_path)
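An illustrative usage sketch of the new S3ToSFTPOperator (not part of the patch itself). It assumes the stock aws_default and ssh_default connections are configured; the DAG id, bucket, key, and remote path are placeholder values.

from datetime import datetime

from airflow import DAG
from airflow.contrib.operators.s3_to_sftp_operator import S3ToSFTPOperator

dag = DAG(
    dag_id='example_s3_to_sftp',
    start_date=datetime(2018, 1, 1),
    schedule_interval=None,
)

# Download s3://my-example-bucket/data/report.csv to a temporary local file,
# then upload it to /tmp/report.csv on the SFTP server behind ssh_default.
s3_to_sftp = S3ToSFTPOperator(
    task_id='s3_to_sftp_example',
    s3_bucket='my-example-bucket',
    s3_key='data/report.csv',
    sftp_path='/tmp/report.csv',
    s3_conn_id='aws_default',
    sftp_conn_id='ssh_default',
    dag=dag,
)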
diff --git a/airflow/contrib/operators/sftp_to_s3_operator.py b/airflow/contrib/operators/sftp_to_s3_operator.py
new file mode 100644
index 0000000000000..b0ed1e16a3630
--- /dev/null
+++ b/airflow/contrib/operators/sftp_to_s3_operator.py
@@ -0,0 +1,89 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from airflow.models import BaseOperator
+from airflow.hooks.S3_hook import S3Hook
+from airflow.contrib.hooks.ssh_hook import SSHHook
+from tempfile import NamedTemporaryFile
+from urllib.parse import urlparse
+from airflow.utils.decorators import apply_defaults
+
+
+class SFTPToS3Operator(BaseOperator):
+    """
+    This operator enables the transferring of files from an SFTP server to Amazon S3.
+    :param sftp_conn_id: The sftp connection id. The name or identifier for
+        establishing a connection to the SFTP server.
+    :type sftp_conn_id: string
+    :param sftp_path: The sftp remote path. This is the specified file
+        path for downloading the file from the SFTP server.
+    :type sftp_path: string
+    :param s3_conn_id: The s3 connection id. The name or identifier for
+        establishing a connection to S3.
+    :type s3_conn_id: string
+    :param s3_bucket: The targeted s3 bucket. This is the S3 bucket
+        to where the file is uploaded.
+    :type s3_bucket: string
+    :param s3_key: The targeted s3 key. This is the specified path
+        for uploading the file to S3.
+    :type s3_key: string
+    """
+
+    template_fields = ('s3_key', 'sftp_path')
+
+    @apply_defaults
+    def __init__(self,
+                 s3_bucket,
+                 s3_key,
+                 sftp_path,
+                 sftp_conn_id='ssh_default',
+                 s3_conn_id='aws_default',
+                 *args,
+                 **kwargs):
+        super(SFTPToS3Operator, self).__init__(*args, **kwargs)
+        self.sftp_conn_id = sftp_conn_id
+        self.sftp_path = sftp_path
+        self.s3_bucket = s3_bucket
+        self.s3_key = s3_key
+        self.s3_conn_id = s3_conn_id
+
+    @staticmethod
+    def get_s3_key(s3_key):
+        """This parses the correct format for S3 keys
+        regardless of how the S3 URL is passed."""
+
+        parsed_s3_key = urlparse(s3_key)
+        return parsed_s3_key.path.lstrip('/')
+
+    def execute(self, context):
+        self.s3_key = self.get_s3_key(self.s3_key)
+        ssh_hook = SSHHook(ssh_conn_id=self.sftp_conn_id)
+        s3_hook = S3Hook(self.s3_conn_id)
+
+        sftp_client = ssh_hook.get_conn().open_sftp()
+
+        with NamedTemporaryFile("w") as f:
+            sftp_client.get(self.sftp_path, f.name)
+
+            s3_hook.load_file(
+                filename=f.name,
+                key=self.s3_key,
+                bucket_name=self.s3_bucket,
+                replace=True
+            )
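A matching sketch for the reverse direction (again illustrative only, with placeholder names). Because get_s3_key() runs the key through urlparse, s3_key can be given either as a bare key or as a full s3:// URL; only the path portion is kept, and the destination bucket always comes from s3_bucket.

from datetime import datetime

from airflow import DAG
from airflow.contrib.operators.sftp_to_s3_operator import SFTPToS3Operator

dag = DAG(
    dag_id='example_sftp_to_s3',
    start_date=datetime(2018, 1, 1),
    schedule_interval=None,
)

# Fetch /tmp/report.csv from the SFTP server and upload it to
# s3://my-example-bucket/data/report.csv; replace=True inside the operator
# means re-runs overwrite the existing object.
sftp_to_s3 = SFTPToS3Operator(
    task_id='sftp_to_s3_example',
    sftp_path='/tmp/report.csv',
    s3_bucket='my-example-bucket',
    # A bare key ('data/report.csv') or a full URL
    # ('s3://my-example-bucket/data/report.csv') resolve to the same key
    # via get_s3_key().
    s3_key='data/report.csv',
    sftp_conn_id='ssh_default',
    s3_conn_id='aws_default',
    dag=dag,
)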
diff --git a/tests/contrib/operators/test_s3_to_sftp_operator.py b/tests/contrib/operators/test_s3_to_sftp_operator.py
new file mode 100644
index 0000000000000..359c5eb912172
--- /dev/null
+++ b/tests/contrib/operators/test_s3_to_sftp_operator.py
@@ -0,0 +1,167 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+
+from airflow import configuration
+from airflow import models
+from airflow.contrib.operators.s3_to_sftp_operator import S3ToSFTPOperator
+from airflow.contrib.operators.ssh_operator import SSHOperator
+from airflow.models import DAG, TaskInstance
+from airflow.settings import Session
+from airflow.utils import timezone
+from airflow.utils.timezone import datetime
+import boto3
+from moto import mock_s3
+
+
+TASK_ID = 'test_s3_to_sftp'
+BUCKET = 'test-s3-bucket'
+S3_KEY = 'test/test_1_file.csv'
+SFTP_PATH = '/tmp/remote_path.txt'
+SFTP_CONN_ID = 'ssh_default'
+S3_CONN_ID = 'aws_default'
+LOCAL_FILE_PATH = '/tmp/test_s3_upload'
+
+SFTP_MOCK_FILE = 'test_sftp_file.csv'
+S3_MOCK_FILES = 'test_1_file.csv'
+
+TEST_DAG_ID = 'unit_tests'
+DEFAULT_DATE = datetime(2018, 1, 1)
+
+
+def reset(dag_id=TEST_DAG_ID):
+    session = Session()
+    tis = session.query(models.TaskInstance).filter_by(dag_id=dag_id)
+    tis.delete()
+    session.commit()
+    session.close()
+
+
+reset()
+
+
+class S3ToSFTPOperatorTest(unittest.TestCase):
+    @mock_s3
+    def setUp(self):
+        configuration.load_test_config()
+        from airflow.contrib.hooks.ssh_hook import SSHHook
+        from airflow.hooks.S3_hook import S3Hook
+
+        hook = SSHHook(ssh_conn_id='ssh_default')
+        s3_hook = S3Hook('aws_default')
+        hook.no_host_key_check = True
+        args = {
+            'owner': 'airflow',
+            'start_date': DEFAULT_DATE,
+            'provide_context': True
+        }
+        dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args)
+        dag.schedule_interval = '@once'
+
+        self.hook = hook
+        self.s3_hook = s3_hook
+
+        self.ssh_client = self.hook.get_conn()
+        self.sftp_client = self.ssh_client.open_sftp()
+
+        self.dag = dag
+        self.s3_bucket = BUCKET
+        self.sftp_path = SFTP_PATH
+        self.s3_key = S3_KEY
+
+    @mock_s3
+    def test_s3_to_sftp_operation(self):
+        # Setup
+        configuration.conf.set("core", "enable_xcom_pickling", "True")
+        test_remote_file_content = \
+            "This is remote file content \n which is also multiline " \
+            "another line here \n this is last line. EOF"
+
+        # Test for creation of s3 bucket
+        conn = boto3.client('s3')
+        conn.create_bucket(Bucket=self.s3_bucket)
+        self.assertTrue((self.s3_hook.check_for_bucket(self.s3_bucket)))
+
+        with open(LOCAL_FILE_PATH, 'w') as f:
+            f.write(test_remote_file_content)
+        self.s3_hook.load_file(LOCAL_FILE_PATH, self.s3_key, bucket_name=BUCKET)
+
+        # Check if object was created in s3
+        objects_in_dest_bucket = conn.list_objects(Bucket=self.s3_bucket,
+                                                   Prefix=self.s3_key)
+        # there should be exactly one object found
+        self.assertEqual(len(objects_in_dest_bucket['Contents']), 1)
+
+        # the object found should be consistent with dest_key specified earlier
+        self.assertEqual(objects_in_dest_bucket['Contents'][0]['Key'], self.s3_key)
+
+        # transfer the file from S3 to the SFTP server
+        run_task = S3ToSFTPOperator(
+            s3_bucket=BUCKET,
+            s3_key=S3_KEY,
+            sftp_path=SFTP_PATH,
+            sftp_conn_id=SFTP_CONN_ID,
+            s3_conn_id=S3_CONN_ID,
+            task_id=TASK_ID,
+            dag=self.dag
+        )
+        self.assertIsNotNone(run_task)
+
+        run_task.execute(None)
+
+        # Check that the file is created remotely
+        check_file_task = SSHOperator(
+            task_id="test_check_file",
+            ssh_hook=self.hook,
+            command="cat {0}".format(self.sftp_path),
+            do_xcom_push=True,
+            dag=self.dag
+        )
+        self.assertIsNotNone(check_file_task)
+        ti3 = TaskInstance(task=check_file_task, execution_date=timezone.utcnow())
+        ti3.run()
+        self.assertEqual(
+            ti3.xcom_pull(task_ids='test_check_file', key='return_value').strip(),
+            test_remote_file_content.encode('utf-8'))
+
+        # Clean up after finishing with test
+        conn.delete_object(Bucket=self.s3_bucket, Key=self.s3_key)
+        conn.delete_bucket(Bucket=self.s3_bucket)
+        self.assertFalse((self.s3_hook.check_for_bucket(self.s3_bucket)))
+
+    def delete_remote_resource(self):
+        # remove the remote test file
+        remove_file_task = SSHOperator(
+            task_id="test_check_file",
+            ssh_hook=self.hook,
+            command="rm {0}".format(self.sftp_path),
+            do_xcom_push=True,
+            dag=self.dag
+        )
+        self.assertIsNotNone(remove_file_task)
+        ti3 = TaskInstance(task=remove_file_task, execution_date=timezone.utcnow())
+        ti3.run()
+
+    def tearDown(self):
+        self.delete_remote_resource()
+
+
+if __name__ == '__main__':
+    unittest.main()
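Note for reviewers: S3 access in this test is served entirely by moto's in-memory S3, while the ssh_default connection still relies on a reachable SSH/SFTP server (typically a local sshd in the test environment). A stripped-down, illustrative sketch of the moto pattern the assertions rely on, with no Airflow or SFTP involved and placeholder bucket/key names:

import boto3
from moto import mock_s3


@mock_s3
def check_bucket_roundtrip():
    # Everything below runs against moto's in-memory S3; no AWS credentials
    # or network access are needed.
    conn = boto3.client('s3')
    conn.create_bucket(Bucket='test-s3-bucket')
    conn.put_object(Bucket='test-s3-bucket', Key='test/test_1_file.csv',
                    Body=b'some,csv,content')
    listed = conn.list_objects(Bucket='test-s3-bucket', Prefix='test/')
    assert len(listed['Contents']) == 1
    assert listed['Contents'][0]['Key'] == 'test/test_1_file.csv'


check_bucket_roundtrip()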
diff --git a/tests/contrib/operators/test_sftp_to_s3_operator.py b/tests/contrib/operators/test_sftp_to_s3_operator.py
new file mode 100644
index 0000000000000..4be3a71d208f6
--- /dev/null
+++ b/tests/contrib/operators/test_sftp_to_s3_operator.py
@@ -0,0 +1,144 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+
+from airflow import configuration
+from airflow import models
+from airflow.contrib.operators.sftp_to_s3_operator import SFTPToS3Operator
+from airflow.contrib.operators.ssh_operator import SSHOperator
+from airflow.models import DAG, TaskInstance
+from airflow.settings import Session
+from airflow.utils import timezone
+from airflow.utils.timezone import datetime
+from airflow.contrib.hooks.ssh_hook import SSHHook
+from airflow.hooks.S3_hook import S3Hook
+
+import boto3
+from moto import mock_s3
+
+BUCKET = 'test-bucket'
+S3_KEY = 'test/test_1_file.csv'
+SFTP_PATH = '/tmp/remote_path.txt'
+SFTP_CONN_ID = 'ssh_default'
+S3_CONN_ID = 'aws_default'
+
+SFTP_MOCK_FILE = 'test_sftp_file.csv'
+S3_MOCK_FILES = 'test_1_file.csv'
+
+TEST_DAG_ID = 'unit_tests'
+DEFAULT_DATE = datetime(2018, 1, 1)
+
+
+def reset(dag_id=TEST_DAG_ID):
+    session = Session()
+    tis = session.query(models.TaskInstance).filter_by(dag_id=dag_id)
+    tis.delete()
+    session.commit()
+    session.close()
+
+
+reset()
+
+
+class SFTPToS3OperatorTest(unittest.TestCase):
+
+    @mock_s3
+    def setUp(self):
+        configuration.load_test_config()
+
+        hook = SSHHook(ssh_conn_id='ssh_default')
+        s3_hook = S3Hook('aws_default')
+        hook.no_host_key_check = True
+        args = {
+            'owner': 'airflow',
+            'start_date': DEFAULT_DATE,
+            'provide_context': True
+        }
+        dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args)
+        dag.schedule_interval = '@once'
+
+        self.hook = hook
+        self.s3_hook = s3_hook
+
+        self.ssh_client = self.hook.get_conn()
+        self.sftp_client = self.ssh_client.open_sftp()
+
+        self.dag = dag
+        self.s3_bucket = BUCKET
+        self.sftp_path = SFTP_PATH
+        self.s3_key = S3_KEY
+
+    @mock_s3
+    def test_sftp_to_s3_operation(self):
+        # Setup
+        configuration.conf.set("core", "enable_xcom_pickling", "True")
+        test_remote_file_content = \
+            "This is remote file content \n which is also multiline " \
+            "another line here \n this is last line. EOF"
+
+        # create a test file remotely
+        create_file_task = SSHOperator(
+            task_id="test_create_file",
+            ssh_hook=self.hook,
+            command="echo '{0}' > {1}".format(test_remote_file_content,
                                              self.sftp_path),
+            do_xcom_push=True,
+            dag=self.dag
+        )
+        self.assertIsNotNone(create_file_task)
+        ti1 = TaskInstance(task=create_file_task, execution_date=timezone.utcnow())
+        ti1.run()
+
+        # Test for creation of s3 bucket
+        conn = boto3.client('s3')
+        conn.create_bucket(Bucket=self.s3_bucket)
+        self.assertTrue((self.s3_hook.check_for_bucket(self.s3_bucket)))
+
+        # transfer the file from the SFTP server to S3
+        run_task = SFTPToS3Operator(
+            s3_bucket=BUCKET,
+            s3_key=S3_KEY,
+            sftp_path=SFTP_PATH,
+            sftp_conn_id=SFTP_CONN_ID,
+            s3_conn_id=S3_CONN_ID,
+            task_id='test_sftp_to_s3',
+            dag=self.dag
+        )
+        self.assertIsNotNone(run_task)
+
+        run_task.execute(None)
+
+        # Check if object was created in s3
+        objects_in_dest_bucket = conn.list_objects(Bucket=self.s3_bucket,
+                                                   Prefix=self.s3_key)
+        # there should be exactly one object found
+        self.assertEqual(len(objects_in_dest_bucket['Contents']), 1)
+
+        # the object found should be consistent with dest_key specified earlier
+        self.assertEqual(objects_in_dest_bucket['Contents'][0]['Key'], self.s3_key)
+
+        # Clean up after finishing with test
+        conn.delete_object(Bucket=self.s3_bucket, Key=self.s3_key)
+        conn.delete_bucket(Bucket=self.s3_bucket)
+        self.assertFalse((self.s3_hook.check_for_bucket(self.s3_bucket)))
+
+
+if __name__ == '__main__':
+    unittest.main()
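Finally, an end-to-end sketch (outside the patch) chaining both new operators in one DAG. It leans on the default connection ids and on the templated s3_key/sftp_path fields declared in template_fields; the DAG id, bucket, and paths are placeholders.

from datetime import datetime

from airflow import DAG
from airflow.contrib.operators.s3_to_sftp_operator import S3ToSFTPOperator
from airflow.contrib.operators.sftp_to_s3_operator import SFTPToS3Operator

with DAG(dag_id='example_s3_sftp_roundtrip',
         start_date=datetime(2018, 1, 1),
         schedule_interval='@daily') as dag:

    # Stage the day's file from S3 onto the SFTP server; {{ ds }} is rendered
    # because s3_key and sftp_path are template_fields on both operators.
    stage_to_sftp = S3ToSFTPOperator(
        task_id='stage_to_sftp',
        s3_bucket='my-example-bucket',
        s3_key='incoming/{{ ds }}/report.csv',
        sftp_path='/tmp/report_{{ ds }}.csv',
    )

    # Archive the same file back to a different S3 prefix once it has landed,
    # using the default ssh_default and aws_default connections.
    archive_to_s3 = SFTPToS3Operator(
        task_id='archive_to_s3',
        sftp_path='/tmp/report_{{ ds }}.csv',
        s3_bucket='my-example-bucket',
        s3_key='archive/{{ ds }}/report.csv',
    )

    stage_to_sftp >> archive_to_s3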