diff --git a/airflow/providers/amazon/aws/hooks/s3.py b/airflow/providers/amazon/aws/hooks/s3.py
index 0f3c0451a42f4..0ef8b25ef3e1d 100644
--- a/airflow/providers/amazon/aws/hooks/s3.py
+++ b/airflow/providers/amazon/aws/hooks/s3.py
@@ -163,7 +163,7 @@ def check_for_bucket(self, bucket_name: Optional[str] = None) -> bool:
         return False
 
     @provide_bucket_name
-    def get_bucket(self, bucket_name: Optional[str] = None) -> str:
+    def get_bucket(self, bucket_name: Optional[str] = None) -> object:
         """
         Returns a boto3.S3.Bucket object
 
diff --git a/airflow/providers/amazon/aws/transfers/gcs_to_s3.py b/airflow/providers/amazon/aws/transfers/gcs_to_s3.py
index 575371a642f3f..64097075c2fc1 100644
--- a/airflow/providers/amazon/aws/transfers/gcs_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/gcs_to_s3.py
@@ -164,14 +164,16 @@ def execute(self, context: 'Context') -> List[str]:
 
         if files:
             for file in files:
-                file_bytes = hook.download(object_name=file, bucket_name=self.bucket)
-
-                dest_key = self.dest_s3_key + file
-                self.log.info("Saving file to %s", dest_key)
-
-                s3_hook.load_bytes(
-                    file_bytes, key=dest_key, replace=self.replace, acl_policy=self.s3_acl_policy
-                )
+                with hook.provide_file(object_name=file, bucket_name=self.bucket) as local_tmp_file:
+                    dest_key = self.dest_s3_key + file
+                    self.log.info("Saving file to %s", dest_key)
+
+                    s3_hook.load_file(
+                        filename=local_tmp_file.name,
+                        key=dest_key,
+                        replace=self.replace,
+                        acl_policy=self.s3_acl_policy,
+                    )
 
             self.log.info("All done, uploaded %d files to S3", len(files))
         else:
diff --git a/tests/providers/amazon/aws/transfers/test_gcs_to_s3.py b/tests/providers/amazon/aws/transfers/test_gcs_to_s3.py
index eb13b3b2d704c..7dd19dc2be4c3 100644
--- a/tests/providers/amazon/aws/transfers/test_gcs_to_s3.py
+++ b/tests/providers/amazon/aws/transfers/test_gcs_to_s3.py
@@ -17,6 +17,7 @@
 # under the License.
 
 import unittest
+from tempfile import NamedTemporaryFile
 from unittest import mock
 
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
@@ -40,164 +41,162 @@ class TestGCSToS3Operator(unittest.TestCase):
 
     # Test1: incremental behaviour (just some files missing)
     @mock_s3
-    @mock.patch('airflow.providers.google.cloud.operators.gcs.GCSHook')
     @mock.patch('airflow.providers.amazon.aws.transfers.gcs_to_s3.GCSHook')
-    def test_execute_incremental(self, mock_hook, mock_hook2):
+    def test_execute_incremental(self, mock_hook):
         mock_hook.return_value.list.return_value = MOCK_FILES
-        mock_hook.return_value.download.return_value = b"testing"
-        mock_hook2.return_value.list.return_value = MOCK_FILES
+        with NamedTemporaryFile() as f:
+            gcs_provide_file = mock_hook.return_value.provide_file
+            gcs_provide_file.return_value.__enter__.return_value.name = f.name
 
-        operator = GCSToS3Operator(
-            task_id=TASK_ID,
-            bucket=GCS_BUCKET,
-            prefix=PREFIX,
-            delimiter=DELIMITER,
-            dest_aws_conn_id="aws_default",
-            dest_s3_key=S3_BUCKET,
-            replace=False,
-        )
-        # create dest bucket
-        hook = S3Hook(aws_conn_id='airflow_gcs_test')
-        bucket = hook.get_bucket('bucket')
-        bucket.create()
-        bucket.put_object(Key=MOCK_FILES[0], Body=b'testing')
+            operator = GCSToS3Operator(
+                task_id=TASK_ID,
+                bucket=GCS_BUCKET,
+                prefix=PREFIX,
+                delimiter=DELIMITER,
+                dest_aws_conn_id="aws_default",
+                dest_s3_key=S3_BUCKET,
+                replace=False,
+            )
+            # create dest bucket
+            hook = S3Hook(aws_conn_id='airflow_gcs_test')
+            bucket = hook.get_bucket('bucket')
+            bucket.create()
+            bucket.put_object(Key=MOCK_FILES[0], Body=b'testing')
 
-        # we expect all except first file in MOCK_FILES to be uploaded
-        # and all the MOCK_FILES to be present at the S3 bucket
-        uploaded_files = operator.execute(None)
-        assert sorted(MOCK_FILES[1:]) == sorted(uploaded_files)
-        assert sorted(MOCK_FILES) == sorted(hook.list_keys('bucket', delimiter='/'))
+            # we expect all except first file in MOCK_FILES to be uploaded
+            # and all the MOCK_FILES to be present at the S3 bucket
+            uploaded_files = operator.execute(None)
+            assert sorted(MOCK_FILES[1:]) == sorted(uploaded_files)
+            assert sorted(MOCK_FILES) == sorted(hook.list_keys('bucket', delimiter='/'))
 
     # Test2: All the files are already in origin and destination without replace
     @mock_s3
-    @mock.patch('airflow.providers.google.cloud.operators.gcs.GCSHook')
     @mock.patch('airflow.providers.amazon.aws.transfers.gcs_to_s3.GCSHook')
-    def test_execute_without_replace(self, mock_hook, mock_hook2):
+    def test_execute_without_replace(self, mock_hook):
         mock_hook.return_value.list.return_value = MOCK_FILES
-        mock_hook.return_value.download.return_value = b"testing"
-        mock_hook2.return_value.list.return_value = MOCK_FILES
+        with NamedTemporaryFile() as f:
+            gcs_provide_file = mock_hook.return_value.provide_file
+            gcs_provide_file.return_value.__enter__.return_value.name = f.name
 
-        operator = GCSToS3Operator(
-            task_id=TASK_ID,
-            bucket=GCS_BUCKET,
-            prefix=PREFIX,
-            delimiter=DELIMITER,
-            dest_aws_conn_id="aws_default",
-            dest_s3_key=S3_BUCKET,
-            replace=False,
-        )
-        # create dest bucket with all the files
-        hook = S3Hook(aws_conn_id='airflow_gcs_test')
-        bucket = hook.get_bucket('bucket')
-        bucket.create()
-        for mock_file in MOCK_FILES:
-            bucket.put_object(Key=mock_file, Body=b'testing')
+            operator = GCSToS3Operator(
+                task_id=TASK_ID,
+                bucket=GCS_BUCKET,
+                prefix=PREFIX,
+                delimiter=DELIMITER,
+                dest_aws_conn_id="aws_default",
+                dest_s3_key=S3_BUCKET,
+                replace=False,
+            )
+            # create dest bucket with all the files
+            hook = S3Hook(aws_conn_id='airflow_gcs_test')
+            bucket = hook.get_bucket('bucket')
+            bucket.create()
+            for mock_file in MOCK_FILES:
+                bucket.put_object(Key=mock_file, Body=b'testing')
 
-        # we expect nothing to be uploaded
-        # and all the MOCK_FILES to be present at the S3 bucket
-        uploaded_files = operator.execute(None)
-        assert [] == uploaded_files
-        assert sorted(MOCK_FILES) == sorted(hook.list_keys('bucket', delimiter='/'))
+            # we expect nothing to be uploaded
+            # and all the MOCK_FILES to be present at the S3 bucket
+            uploaded_files = operator.execute(None)
+            assert [] == uploaded_files
+            assert sorted(MOCK_FILES) == sorted(hook.list_keys('bucket', delimiter='/'))
 
     # Test3: There are no files in destination bucket
     @mock_s3
-    @mock.patch('airflow.providers.google.cloud.operators.gcs.GCSHook')
     @mock.patch('airflow.providers.amazon.aws.transfers.gcs_to_s3.GCSHook')
-    def test_execute(self, mock_hook, mock_hook2):
+    def test_execute(self, mock_hook):
         mock_hook.return_value.list.return_value = MOCK_FILES
-        mock_hook.return_value.download.return_value = b"testing"
-        mock_hook2.return_value.list.return_value = MOCK_FILES
+        with NamedTemporaryFile() as f:
+            gcs_provide_file = mock_hook.return_value.provide_file
+            gcs_provide_file.return_value.__enter__.return_value.name = f.name
 
-        operator = GCSToS3Operator(
-            task_id=TASK_ID,
-            bucket=GCS_BUCKET,
-            prefix=PREFIX,
-            delimiter=DELIMITER,
-            dest_aws_conn_id="aws_default",
-            dest_s3_key=S3_BUCKET,
-            replace=False,
-        )
-        # create dest bucket without files
-        hook = S3Hook(aws_conn_id='airflow_gcs_test')
-        bucket = hook.get_bucket('bucket')
-        bucket.create()
+            operator = GCSToS3Operator(
+                task_id=TASK_ID,
+                bucket=GCS_BUCKET,
+                prefix=PREFIX,
+                delimiter=DELIMITER,
+                dest_aws_conn_id="aws_default",
+                dest_s3_key=S3_BUCKET,
+                replace=False,
+            )
+            # create dest bucket without files
+            hook = S3Hook(aws_conn_id='airflow_gcs_test')
+            bucket = hook.get_bucket('bucket')
+            bucket.create()
 
-        # we expect all MOCK_FILES to be uploaded
-        # and all MOCK_FILES to be present at the S3 bucket
-        uploaded_files = operator.execute(None)
-        assert sorted(MOCK_FILES) == sorted(uploaded_files)
-        assert sorted(MOCK_FILES) == sorted(hook.list_keys('bucket', delimiter='/'))
+            # we expect all MOCK_FILES to be uploaded
+            # and all MOCK_FILES to be present at the S3 bucket
+            uploaded_files = operator.execute(None)
+            assert sorted(MOCK_FILES) == sorted(uploaded_files)
+            assert sorted(MOCK_FILES) == sorted(hook.list_keys('bucket', delimiter='/'))
 
     # Test4: Destination and Origin are in sync but replace all files in destination
     @mock_s3
-    @mock.patch('airflow.providers.google.cloud.operators.gcs.GCSHook')
     @mock.patch('airflow.providers.amazon.aws.transfers.gcs_to_s3.GCSHook')
-    def test_execute_with_replace(self, mock_hook, mock_hook2):
+    def test_execute_with_replace(self, mock_hook):
         mock_hook.return_value.list.return_value = MOCK_FILES
-        mock_hook.return_value.download.return_value = b"testing"
-        mock_hook2.return_value.list.return_value = MOCK_FILES
+        with NamedTemporaryFile() as f:
+            gcs_provide_file = mock_hook.return_value.provide_file
+            gcs_provide_file.return_value.__enter__.return_value.name = f.name
 
-        operator = GCSToS3Operator(
-            task_id=TASK_ID,
-            bucket=GCS_BUCKET,
-            prefix=PREFIX,
-            delimiter=DELIMITER,
-            dest_aws_conn_id="aws_default",
-            dest_s3_key=S3_BUCKET,
-            replace=True,
-        )
-        # create dest bucket with all the files
-        hook = S3Hook(aws_conn_id='airflow_gcs_test')
-        bucket = hook.get_bucket('bucket')
-        bucket.create()
-        for mock_file in MOCK_FILES:
-            bucket.put_object(Key=mock_file, Body=b'testing')
+            operator = GCSToS3Operator(
+                task_id=TASK_ID,
+                bucket=GCS_BUCKET,
+                prefix=PREFIX,
+                delimiter=DELIMITER,
+                dest_aws_conn_id="aws_default",
+                dest_s3_key=S3_BUCKET,
+                replace=True,
+            )
+            # create dest bucket with all the files
+            hook = S3Hook(aws_conn_id='airflow_gcs_test')
+            bucket = hook.get_bucket('bucket')
+            bucket.create()
+            for mock_file in MOCK_FILES:
+                bucket.put_object(Key=mock_file, Body=b'testing')
 
-        # we expect all MOCK_FILES to be uploaded and replace the existing ones
-        # and all MOCK_FILES to be present at the S3 bucket
-        uploaded_files = operator.execute(None)
-        assert sorted(MOCK_FILES) == sorted(uploaded_files)
-        assert sorted(MOCK_FILES) == sorted(hook.list_keys('bucket', delimiter='/'))
+            # we expect all MOCK_FILES to be uploaded and replace the existing ones
+            # and all MOCK_FILES to be present at the S3 bucket
+            uploaded_files = operator.execute(None)
+            assert sorted(MOCK_FILES) == sorted(uploaded_files)
+            assert sorted(MOCK_FILES) == sorted(hook.list_keys('bucket', delimiter='/'))
 
     # Test5: Incremental sync with replace
     @mock_s3
-    @mock.patch('airflow.providers.google.cloud.operators.gcs.GCSHook')
     @mock.patch('airflow.providers.amazon.aws.transfers.gcs_to_s3.GCSHook')
-    def test_execute_incremental_with_replace(self, mock_hook, mock_hook2):
+    def test_execute_incremental_with_replace(self, mock_hook):
         mock_hook.return_value.list.return_value = MOCK_FILES
-        mock_hook.return_value.download.return_value = b"testing"
-        mock_hook2.return_value.list.return_value = MOCK_FILES
+        with NamedTemporaryFile() as f:
+            gcs_provide_file = mock_hook.return_value.provide_file
+            gcs_provide_file.return_value.__enter__.return_value.name = f.name
 
-        operator = GCSToS3Operator(
-            task_id=TASK_ID,
-            bucket=GCS_BUCKET,
-            prefix=PREFIX,
-            delimiter=DELIMITER,
-            dest_aws_conn_id="aws_default",
-            dest_s3_key=S3_BUCKET,
-            replace=True,
-        )
-        # create dest bucket with just two files (the first two files in MOCK_FILES)
-        hook = S3Hook(aws_conn_id='airflow_gcs_test')
-        bucket = hook.get_bucket('bucket')
-        bucket.create()
-        for mock_file in MOCK_FILES[:2]:
-            bucket.put_object(Key=mock_file, Body=b'testing')
+            operator = GCSToS3Operator(
+                task_id=TASK_ID,
+                bucket=GCS_BUCKET,
+                prefix=PREFIX,
+                delimiter=DELIMITER,
+                dest_aws_conn_id="aws_default",
+                dest_s3_key=S3_BUCKET,
+                replace=True,
+            )
+            # create dest bucket with just two files (the first two files in MOCK_FILES)
+            hook = S3Hook(aws_conn_id='airflow_gcs_test')
+            bucket = hook.get_bucket('bucket')
+            bucket.create()
+            for mock_file in MOCK_FILES[:2]:
+                bucket.put_object(Key=mock_file, Body=b'testing')
 
-        # we expect all the MOCK_FILES to be uploaded and replace the existing ones
-        # and all MOCK_FILES to be present at the S3 bucket
-        uploaded_files = operator.execute(None)
-        assert sorted(MOCK_FILES) == sorted(uploaded_files)
-        assert sorted(MOCK_FILES) == sorted(hook.list_keys('bucket', delimiter='/'))
+            # we expect all the MOCK_FILES to be uploaded and replace the existing ones
+            # and all MOCK_FILES to be present at the S3 bucket
+            uploaded_files = operator.execute(None)
+            assert sorted(MOCK_FILES) == sorted(uploaded_files)
+            assert sorted(MOCK_FILES) == sorted(hook.list_keys('bucket', delimiter='/'))
 
     @mock_s3
-    @mock.patch('airflow.providers.google.cloud.operators.gcs.GCSHook')
     @mock.patch('airflow.providers.amazon.aws.transfers.gcs_to_s3.GCSHook')
     @mock.patch('airflow.providers.amazon.aws.transfers.gcs_to_s3.S3Hook')
-    def test_execute_should_handle_with_default_dest_s3_extra_args(self, s3_mock_hook, mock_hook, mock_hook2):
+    def test_execute_should_handle_with_default_dest_s3_extra_args(self, s3_mock_hook, mock_hook):
         mock_hook.return_value.list.return_value = MOCK_FILES
         mock_hook.return_value.download.return_value = b"testing"
-        mock_hook2.return_value.list.return_value = MOCK_FILES
         s3_mock_hook.return_value = mock.Mock()
         s3_mock_hook.parse_s3_url.return_value = mock.Mock()
 
@@ -214,61 +213,61 @@ def test_execute_should_handle_with_default_dest_s3_extra_args(self, s3_mock_hoo
         s3_mock_hook.assert_called_once_with(aws_conn_id='aws_default', extra_args={}, verify=None)
 
     @mock_s3
-    @mock.patch('airflow.providers.google.cloud.operators.gcs.GCSHook')
     @mock.patch('airflow.providers.amazon.aws.transfers.gcs_to_s3.GCSHook')
     @mock.patch('airflow.providers.amazon.aws.transfers.gcs_to_s3.S3Hook')
-    def test_execute_should_pass_dest_s3_extra_args_to_s3_hook(self, s3_mock_hook, mock_hook, mock_hook2):
+    def test_execute_should_pass_dest_s3_extra_args_to_s3_hook(self, s3_mock_hook, mock_hook):
         mock_hook.return_value.list.return_value = MOCK_FILES
-        mock_hook.return_value.download.return_value = b"testing"
-        mock_hook2.return_value.list.return_value = MOCK_FILES
-        s3_mock_hook.return_value = mock.Mock()
-        s3_mock_hook.parse_s3_url.return_value = mock.Mock()
+        with NamedTemporaryFile() as f:
+            gcs_provide_file = mock_hook.return_value.provide_file
+            gcs_provide_file.return_value.__enter__.return_value.name = f.name
+            s3_mock_hook.return_value = mock.Mock()
+            s3_mock_hook.parse_s3_url.return_value = mock.Mock()
 
-        operator = GCSToS3Operator(
-            task_id=TASK_ID,
-            bucket=GCS_BUCKET,
-            prefix=PREFIX,
-            delimiter=DELIMITER,
-            dest_aws_conn_id="aws_default",
-            dest_s3_key=S3_BUCKET,
-            replace=True,
-            dest_s3_extra_args={
-                "ContentLanguage": "value",
-            },
-        )
-        operator.execute(None)
-        s3_mock_hook.assert_called_once_with(
-            aws_conn_id='aws_default', extra_args={'ContentLanguage': 'value'}, verify=None
-        )
+            operator = GCSToS3Operator(
+                task_id=TASK_ID,
+                bucket=GCS_BUCKET,
+                prefix=PREFIX,
+                delimiter=DELIMITER,
+                dest_aws_conn_id="aws_default",
+                dest_s3_key=S3_BUCKET,
+                replace=True,
+                dest_s3_extra_args={
+                    "ContentLanguage": "value",
+                },
+            )
+            operator.execute(None)
+            s3_mock_hook.assert_called_once_with(
+                aws_conn_id='aws_default', extra_args={'ContentLanguage': 'value'}, verify=None
+            )
 
     # Test6: s3_acl_policy parameter is set
     @mock_s3
-    @mock.patch('airflow.providers.google.cloud.operators.gcs.GCSHook')
     @mock.patch('airflow.providers.amazon.aws.transfers.gcs_to_s3.GCSHook')
-    @mock.patch('airflow.providers.amazon.aws.hooks.s3.S3Hook.load_bytes')
-    def test_execute_with_s3_acl_policy(self, mock_load_bytes, mock_gcs_hook, mock_gcs_hook2):
+    @mock.patch('airflow.providers.amazon.aws.hooks.s3.S3Hook.load_file')
+    def test_execute_with_s3_acl_policy(self, mock_load_file, mock_gcs_hook):
         mock_gcs_hook.return_value.list.return_value = MOCK_FILES
-        mock_gcs_hook.return_value.download.return_value = b"testing"
-        mock_gcs_hook2.return_value.list.return_value = MOCK_FILES
+        with NamedTemporaryFile() as f:
+            gcs_provide_file = mock_gcs_hook.return_value.provide_file
+            gcs_provide_file.return_value.__enter__.return_value.name = f.name
 
-        operator = GCSToS3Operator(
-            task_id=TASK_ID,
-            bucket=GCS_BUCKET,
-            prefix=PREFIX,
-            delimiter=DELIMITER,
-            dest_aws_conn_id="aws_default",
-            dest_s3_key=S3_BUCKET,
-            replace=False,
-            s3_acl_policy=S3_ACL_POLICY,
-        )
+            operator = GCSToS3Operator(
+                task_id=TASK_ID,
+                bucket=GCS_BUCKET,
+                prefix=PREFIX,
+                delimiter=DELIMITER,
+                dest_aws_conn_id="aws_default",
+                dest_s3_key=S3_BUCKET,
+                replace=False,
+                s3_acl_policy=S3_ACL_POLICY,
+            )
 
-        # Create dest bucket without files
-        hook = S3Hook(aws_conn_id='airflow_gcs_test')
-        bucket = hook.get_bucket('bucket')
-        bucket.create()
+            # Create dest bucket without files
+            hook = S3Hook(aws_conn_id='airflow_gcs_test')
+            bucket = hook.get_bucket('bucket')
+            bucket.create()
 
-        operator.execute(None)
+            operator.execute(None)
 
-        # Make sure the acl_policy parameter is passed to the upload method
-        _, kwargs = mock_load_bytes.call_args
-        assert kwargs['acl_policy'] == S3_ACL_POLICY
+            # Make sure the acl_policy parameter is passed to the upload method
+            _, kwargs = mock_load_file.call_args
+            assert kwargs['acl_policy'] == S3_ACL_POLICY
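
For context, the operator change above replaces the in-memory copy (GCSHook.download returning the object bytes, then S3Hook.load_bytes) with a copy staged through a local temporary file: GCSHook.provide_file downloads the GCS object into a NamedTemporaryFile and yields it as a context manager, and S3Hook.load_file uploads that file by name, so an object no longer has to fit in memory. The snippet below is a minimal standalone sketch of that hook pattern and is not part of the patch; the connection ids, bucket names and prefix are illustrative placeholders.

from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.google.cloud.hooks.gcs import GCSHook

# Placeholder connection ids and bucket names, for illustration only.
gcs_hook = GCSHook(gcp_conn_id="google_cloud_default")
s3_hook = S3Hook(aws_conn_id="aws_default")

for file in gcs_hook.list(bucket_name="example-gcs-bucket", prefix="data/"):
    # provide_file stages the GCS object in a local NamedTemporaryFile
    # and removes it when the context manager exits.
    with gcs_hook.provide_file(object_name=file, bucket_name="example-gcs-bucket") as local_tmp_file:
        # load_file parses the bucket out of the s3:// key, mirroring how
        # GCSToS3Operator builds dest_key from dest_s3_key + file.
        s3_hook.load_file(
            filename=local_tmp_file.name,
            key="s3://example-s3-bucket/" + file,
            replace=True,
        )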