diff --git a/airflow/providers/amazon/aws/sensors/s3.py b/airflow/providers/amazon/aws/sensors/s3.py index 8e9e55aa52b98..2bda06639f5dc 100644 --- a/airflow/providers/amazon/aws/sensors/s3.py +++ b/airflow/providers/amazon/aws/sensors/s3.py @@ -77,7 +77,6 @@ def __init__( **kwargs, ): super().__init__(**kwargs) - self.bucket_name = bucket_name self.bucket_key = bucket_key self.wildcard_match = wildcard_match @@ -85,23 +84,17 @@ def __init__( self.verify = verify self.hook: Optional[S3Hook] = None - def poke(self, context: 'Context'): - + def _resolve_bucket_and_key(self): + """If key is URI, parse bucket""" if self.bucket_name is None: - parsed_url = urlparse(self.bucket_key) - if parsed_url.netloc == '': - raise AirflowException('If key is a relative path from root, please provide a bucket_name') - self.bucket_name = parsed_url.netloc - self.bucket_key = parsed_url.path.lstrip('/') + self.bucket_name, self.bucket_key = S3Hook.parse_s3_url(self.bucket_key) else: parsed_url = urlparse(self.bucket_key) if parsed_url.scheme != '' or parsed_url.netloc != '': - raise AirflowException( - 'If bucket_name is provided, bucket_key' - ' should be relative path from root' - ' level, rather than a full s3:// url' - ) + raise AirflowException('If bucket_name provided, bucket_key must be relative path, not URI.') + def poke(self, context: 'Context'): + self._resolve_bucket_and_key() self.log.info('Poking for key : s3://%s/%s', self.bucket_name, self.bucket_key) if self.wildcard_match: return self.get_hook().check_for_wildcard_key(self.bucket_key, self.bucket_name) diff --git a/tests/providers/amazon/aws/sensors/test_s3_key.py b/tests/providers/amazon/aws/sensors/test_s3_key.py index ca441246060da..2f6237ebff642 100644 --- a/tests/providers/amazon/aws/sensors/test_s3_key.py +++ b/tests/providers/amazon/aws/sensors/test_s3_key.py @@ -58,9 +58,9 @@ def test_bucket_name_provided_and_bucket_key_is_s3_url(self): ['key', 'bucket', 'key', 'bucket'], ] ) - @mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook') - def test_parse_bucket_key(self, key, bucket, parsed_key, parsed_bucket, mock_hook): - mock_hook.return_value.check_for_key.return_value = False + @mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook.check_for_key') + def test_parse_bucket_key(self, key, bucket, parsed_key, parsed_bucket, mock_check): + mock_check.return_value = False op = S3KeySensor( task_id='s3_key_sensor', @@ -73,9 +73,9 @@ def test_parse_bucket_key(self, key, bucket, parsed_key, parsed_bucket, mock_hoo assert op.bucket_key == parsed_key assert op.bucket_name == parsed_bucket - @mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook') - def test_parse_bucket_key_from_jinja(self, mock_hook): - mock_hook.return_value.check_for_key.return_value = False + @mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook.check_for_key') + def test_parse_bucket_key_from_jinja(self, mock_check): + mock_check.return_value = False Variable.set("test_bucket_key", "s3://bucket/key") @@ -94,49 +94,46 @@ def test_parse_bucket_key_from_jinja(self, mock_hook): ti.dag_run = dag_run context = ti.get_template_context() ti.render_templates(context) - op.poke(None) assert op.bucket_key == "key" assert op.bucket_name == "bucket" - @mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook') - def test_poke(self, mock_hook): + @mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook.check_for_key') + def test_poke(self, mock_check): op = S3KeySensor(task_id='s3_key_sensor', bucket_key='s3://test_bucket/file') - mock_check_for_key = mock_hook.return_value.check_for_key - mock_check_for_key.return_value = False - assert not op.poke(None) - mock_check_for_key.assert_called_once_with(op.bucket_key, op.bucket_name) + mock_check.return_value = False + assert op.poke(None) is False + mock_check.assert_called_once_with(op.bucket_key, op.bucket_name) - mock_hook.return_value.check_for_key.return_value = True - assert op.poke(None) + mock_check.return_value = True + assert op.poke(None) is True - @mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook') - def test_poke_wildcard(self, mock_hook): + @mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook.check_for_wildcard_key') + def test_poke_wildcard(self, mock_check): op = S3KeySensor(task_id='s3_key_sensor', bucket_key='s3://test_bucket/file', wildcard_match=True) - mock_check_for_wildcard_key = mock_hook.return_value.check_for_wildcard_key - mock_check_for_wildcard_key.return_value = False - assert not op.poke(None) - mock_check_for_wildcard_key.assert_called_once_with(op.bucket_key, op.bucket_name) + mock_check.return_value = False + assert op.poke(None) is False + mock_check.assert_called_once_with(op.bucket_key, op.bucket_name) - mock_check_for_wildcard_key.return_value = True - assert op.poke(None) + mock_check.return_value = True + assert op.poke(None) is True class TestS3KeySizeSensor(unittest.TestCase): @mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook.check_for_key', return_value=False) def test_poke_check_for_key_false(self, mock_check_for_key): op = S3KeySizeSensor(task_id='s3_key_sensor', bucket_key='s3://test_bucket/file') - assert not op.poke(None) + assert op.poke(None) is False mock_check_for_key.assert_called_once_with(op.bucket_key, op.bucket_name) @mock.patch('airflow.providers.amazon.aws.sensors.s3.S3KeySizeSensor.get_files', return_value=[]) @mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook.check_for_key', return_value=True) def test_poke_get_files_false(self, mock_check_for_key, mock_get_files): op = S3KeySizeSensor(task_id='s3_key_sensor', bucket_key='s3://test_bucket/file') - assert not op.poke(None) + assert op.poke(None) is False mock_check_for_key.assert_called_once_with(op.bucket_key, op.bucket_name) mock_get_files.assert_called_once_with(s3_hook=op.get_hook()) @@ -150,28 +147,25 @@ def test_poke_get_files_false(self, mock_check_for_key, mock_get_files): [{"Contents": [{"Size": 10}, {"Size": 10}]}, True], ] ) - @mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook') - def test_poke(self, paginate_return_value, poke_return_value, mock_hook): + @mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook.get_conn') + @mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook.check_for_key') + def test_poke(self, paginate_return_value, poke_return_value, mock_check, mock_get_conn): op = S3KeySizeSensor(task_id='s3_key_sensor', bucket_key='s3://test_bucket/file') - mock_check_for_key = mock_hook.return_value.check_for_key - mock_hook.return_value.check_for_key.return_value = True + mock_check.return_value = True mock_paginator = mock.Mock() mock_paginator.paginate.return_value = [] - mock_conn = mock.Mock() - mock_conn.return_value.get_paginator.return_value = mock_paginator - mock_hook.return_value.get_conn = mock_conn + mock_get_conn.return_value.get_paginator.return_value = mock_paginator mock_paginator.paginate.return_value = [paginate_return_value] assert op.poke(None) is poke_return_value - mock_check_for_key.assert_called_once_with(op.bucket_key, op.bucket_name) + mock_check.assert_called_once_with(op.bucket_key, op.bucket_name) @mock.patch('airflow.providers.amazon.aws.sensors.s3.S3KeySizeSensor.get_files', return_value=[]) - @mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook') - def test_poke_wildcard(self, mock_hook, mock_get_files): + @mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook.check_for_wildcard_key') + def test_poke_wildcard(self, mock_check, mock_get_files): op = S3KeySizeSensor(task_id='s3_key_sensor', bucket_key='s3://test_bucket/file', wildcard_match=True) - mock_check_for_wildcard_key = mock_hook.return_value.check_for_wildcard_key - mock_check_for_wildcard_key.return_value = False - assert not op.poke(None) - mock_check_for_wildcard_key.assert_called_once_with(op.bucket_key, op.bucket_name) + mock_check.return_value = False + assert op.poke(None) is False + mock_check.assert_called_once_with(op.bucket_key, op.bucket_name)