From 25a9855a51e121c9b00765539749e34e4c54ee87 Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Thu, 10 Feb 2022 09:44:49 -0800 Subject: [PATCH 1/5] S3KeySensor should parse bucket and key outside of poke Rather than relying on the poke method to parse the attrs we should either do it in `__init__` or in a cached property. Since it's a pretty insignificant computation let's just do it in `__init__`. And anyway, if the params are bad then best to get a warning before deploying your code. --- airflow/providers/amazon/aws/sensors/s3.py | 45 ++++++++++++---------- 1 file changed, 25 insertions(+), 20 deletions(-) diff --git a/airflow/providers/amazon/aws/sensors/s3.py b/airflow/providers/amazon/aws/sensors/s3.py index 8e9e55aa52b98..fe82d32ed20ed 100644 --- a/airflow/providers/amazon/aws/sensors/s3.py +++ b/airflow/providers/amazon/aws/sensors/s3.py @@ -21,7 +21,7 @@ import re import sys from datetime import datetime -from typing import TYPE_CHECKING, Callable, List, Optional, Sequence, Set, Union +from typing import TYPE_CHECKING, Callable, List, Optional, Sequence, Set, Tuple, Union from urllib.parse import urlparse if TYPE_CHECKING: @@ -77,31 +77,36 @@ def __init__( **kwargs, ): super().__init__(**kwargs) - - self.bucket_name = bucket_name - self.bucket_key = bucket_key self.wildcard_match = wildcard_match self.aws_conn_id = aws_conn_id self.verify = verify self.hook: Optional[S3Hook] = None - - def poke(self, context: 'Context'): - - if self.bucket_name is None: - parsed_url = urlparse(self.bucket_key) - if parsed_url.netloc == '': - raise AirflowException('If key is a relative path from root, please provide a bucket_name') - self.bucket_name = parsed_url.netloc - self.bucket_key = parsed_url.path.lstrip('/') + if bucket_name is None: + self.bucket_name, self.bucket_key = self._parse_object_from_uri(bucket_key) else: - parsed_url = urlparse(self.bucket_key) - if parsed_url.scheme != '' or parsed_url.netloc != '': - 
raise AirflowException( - 'If bucket_name is provided, bucket_key' - ' should be relative path from root' - ' level, rather than a full s3:// url' - ) + self.bucket_name = bucket_name + self._validate_key(bucket_key) + self.bucket_key = bucket_key + + @staticmethod + def _parse_object_from_uri(uri) -> Tuple[str, str]: + parsed = urlparse(uri) + if parsed.netloc == '': + raise AirflowException('If key is a relative path from root, please provide a bucket_name') + bucket = parsed.netloc + key = parsed.path.lstrip('/') + return bucket, key + + @staticmethod + def _validate_key(key): + parsed = urlparse(key) + if parsed.scheme != '' or parsed.netloc != '': + raise AirflowException( + 'If bucket_name is provided, bucket_key should be relative path from root level, ' + 'rather than a full s3:// url' + ) + def poke(self, context: 'Context'): self.log.info('Poking for key : s3://%s/%s', self.bucket_name, self.bucket_key) if self.wildcard_match: return self.get_hook().check_for_wildcard_key(self.bucket_key, self.bucket_name) From ba4f318b37abfdf726ff375bb136e8f6ce102881 Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Thu, 10 Feb 2022 15:35:40 -0800 Subject: [PATCH 2/5] must resolve only after template rendering --- airflow/providers/amazon/aws/sensors/s3.py | 44 ++++++-------- .../amazon/aws/sensors/test_s3_key.py | 58 +++++++++---------- 2 files changed, 45 insertions(+), 57 deletions(-) diff --git a/airflow/providers/amazon/aws/sensors/s3.py b/airflow/providers/amazon/aws/sensors/s3.py index fe82d32ed20ed..4b35f68eb24d4 100644 --- a/airflow/providers/amazon/aws/sensors/s3.py +++ b/airflow/providers/amazon/aws/sensors/s3.py @@ -21,7 +21,7 @@ import re import sys from datetime import datetime -from typing import TYPE_CHECKING, Callable, List, Optional, Sequence, Set, Tuple, Union +from typing import TYPE_CHECKING, Callable, List, Optional, Sequence, Set, Union from urllib.parse import urlparse if TYPE_CHECKING: @@ -77,39 
+77,33 @@ def __init__( **kwargs, ): super().__init__(**kwargs) + self.bucket_key = bucket_key + self.bucket_name = bucket_name self.wildcard_match = wildcard_match self.aws_conn_id = aws_conn_id self.verify = verify self.hook: Optional[S3Hook] = None - if bucket_name is None: - self.bucket_name, self.bucket_key = self._parse_object_from_uri(bucket_key) + + def _resolve_bucket_and_key(self): + """ + If key is URI we should parse the bucket and leave only key portion under + the ``bucket_key`` attr. + """ + if self.bucket_name is None: + self.bucket_name, self.bucket_key = S3Hook.parse_s3_url(self.bucket_key) else: - self.bucket_name = bucket_name - self._validate_key(bucket_key) - self.bucket_key = bucket_key - - @staticmethod - def _parse_object_from_uri(uri) -> Tuple[str, str]: - parsed = urlparse(uri) - if parsed.netloc == '': - raise AirflowException('If key is a relative path from root, please provide a bucket_name') - bucket = parsed.netloc - key = parsed.path.lstrip('/') - return bucket, key - - @staticmethod - def _validate_key(key): - parsed = urlparse(key) - if parsed.scheme != '' or parsed.netloc != '': - raise AirflowException( - 'If bucket_name is provided, bucket_key should be relative path from root level, ' - 'rather than a full s3:// url' - ) + parsed = urlparse(self.bucket_key) + if parsed.scheme != '' or parsed.netloc != '': + raise AirflowException( + 'If bucket_name is provided, bucket_key should be relative path from root level, ' + 'rather than a full s3:// url' + ) def poke(self, context: 'Context'): + self._resolve_bucket_and_key() self.log.info('Poking for key : s3://%s/%s', self.bucket_name, self.bucket_key) if self.wildcard_match: - return self.get_hook().check_for_wildcard_key(self.bucket_key, self.bucket_name) + return self.get_hook().f(self.bucket_key, self.bucket_name) return self.get_hook().check_for_key(self.bucket_key, self.bucket_name) def get_hook(self) -> S3Hook: diff --git a/tests/providers/amazon/aws/sensors/test_s3_key.py 
b/tests/providers/amazon/aws/sensors/test_s3_key.py index ca441246060da..a12fed3eae73e 100644 --- a/tests/providers/amazon/aws/sensors/test_s3_key.py +++ b/tests/providers/amazon/aws/sensors/test_s3_key.py @@ -58,9 +58,9 @@ def test_bucket_name_provided_and_bucket_key_is_s3_url(self): ['key', 'bucket', 'key', 'bucket'], ] ) - @mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook') - def test_parse_bucket_key(self, key, bucket, parsed_key, parsed_bucket, mock_hook): - mock_hook.return_value.check_for_key.return_value = False + @mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook.check_for_key') + def test_parse_bucket_key(self, key, bucket, parsed_key, parsed_bucket, mock_check): + mock_check.return_value = False op = S3KeySensor( task_id='s3_key_sensor', @@ -73,9 +73,9 @@ def test_parse_bucket_key(self, key, bucket, parsed_key, parsed_bucket, mock_hoo assert op.bucket_key == parsed_key assert op.bucket_name == parsed_bucket - @mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook') - def test_parse_bucket_key_from_jinja(self, mock_hook): - mock_hook.return_value.check_for_key.return_value = False + @mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook.check_for_key') + def test_parse_bucket_key_from_jinja(self, mock_check): + mock_check.return_value = False Variable.set("test_bucket_key", "s3://bucket/key") @@ -94,34 +94,31 @@ def test_parse_bucket_key_from_jinja(self, mock_hook): ti.dag_run = dag_run context = ti.get_template_context() ti.render_templates(context) - op.poke(None) assert op.bucket_key == "key" assert op.bucket_name == "bucket" - @mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook') - def test_poke(self, mock_hook): + @mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook.check_for_key') + def test_poke(self, mock_check): op = S3KeySensor(task_id='s3_key_sensor', bucket_key='s3://test_bucket/file') - mock_check_for_key = mock_hook.return_value.check_for_key - mock_check_for_key.return_value = False + 
mock_check.return_value = False assert not op.poke(None) - mock_check_for_key.assert_called_once_with(op.bucket_key, op.bucket_name) + mock_check.assert_called_once_with(op.bucket_key, op.bucket_name) - mock_hook.return_value.check_for_key.return_value = True + mock_check.return_value = True assert op.poke(None) - @mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook') - def test_poke_wildcard(self, mock_hook): + @mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook.check_for_wildcard_key') + def test_poke_wildcard(self, mock_check): op = S3KeySensor(task_id='s3_key_sensor', bucket_key='s3://test_bucket/file', wildcard_match=True) - mock_check_for_wildcard_key = mock_hook.return_value.check_for_wildcard_key - mock_check_for_wildcard_key.return_value = False + mock_check.return_value = False assert not op.poke(None) - mock_check_for_wildcard_key.assert_called_once_with(op.bucket_key, op.bucket_name) + mock_check.assert_called_once_with(op.bucket_key, op.bucket_name) - mock_check_for_wildcard_key.return_value = True + mock_check.return_value = True assert op.poke(None) @@ -150,28 +147,25 @@ def test_poke_get_files_false(self, mock_check_for_key, mock_get_files): [{"Contents": [{"Size": 10}, {"Size": 10}]}, True], ] ) - @mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook') - def test_poke(self, paginate_return_value, poke_return_value, mock_hook): + @mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook.get_conn') + @mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook.check_for_key') + def test_poke(self, paginate_return_value, poke_return_value, mock_check, mock_get_conn): op = S3KeySizeSensor(task_id='s3_key_sensor', bucket_key='s3://test_bucket/file') - mock_check_for_key = mock_hook.return_value.check_for_key - mock_hook.return_value.check_for_key.return_value = True + mock_check.return_value = True mock_paginator = mock.Mock() mock_paginator.paginate.return_value = [] - mock_conn = mock.Mock() - 
mock_conn.return_value.get_paginator.return_value = mock_paginator - mock_hook.return_value.get_conn = mock_conn + mock_get_conn.return_value.get_paginator.return_value = mock_paginator mock_paginator.paginate.return_value = [paginate_return_value] assert op.poke(None) is poke_return_value - mock_check_for_key.assert_called_once_with(op.bucket_key, op.bucket_name) + mock_check.assert_called_once_with(op.bucket_key, op.bucket_name) @mock.patch('airflow.providers.amazon.aws.sensors.s3.S3KeySizeSensor.get_files', return_value=[]) - @mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook') - def test_poke_wildcard(self, mock_hook, mock_get_files): + @mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook.check_for_wildcard_key') + def test_poke_wildcard(self, mock_check, mock_get_files): op = S3KeySizeSensor(task_id='s3_key_sensor', bucket_key='s3://test_bucket/file', wildcard_match=True) - mock_check_for_wildcard_key = mock_hook.return_value.check_for_wildcard_key - mock_check_for_wildcard_key.return_value = False + mock_check.return_value = False assert not op.poke(None) - mock_check_for_wildcard_key.assert_called_once_with(op.bucket_key, op.bucket_name) + mock_check.assert_called_once_with(op.bucket_key, op.bucket_name) From ce061ac9d650406fc472be03db1e5ced6c35f4ea Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Thu, 10 Feb 2022 15:36:56 -0800 Subject: [PATCH 3/5] fixup! 
must resolve only after template rendering --- airflow/providers/amazon/aws/sensors/s3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/providers/amazon/aws/sensors/s3.py b/airflow/providers/amazon/aws/sensors/s3.py index 4b35f68eb24d4..1e621ab11c49d 100644 --- a/airflow/providers/amazon/aws/sensors/s3.py +++ b/airflow/providers/amazon/aws/sensors/s3.py @@ -103,7 +103,7 @@ def poke(self, context: 'Context'): self._resolve_bucket_and_key() self.log.info('Poking for key : s3://%s/%s', self.bucket_name, self.bucket_key) if self.wildcard_match: - return self.get_hook().f(self.bucket_key, self.bucket_name) + return self.get_hook().check_for_wildcard_key(self.bucket_key, self.bucket_name) return self.get_hook().check_for_key(self.bucket_key, self.bucket_name) def get_hook(self) -> S3Hook: From 1e40c0341f181db8d98353d5151d867716c6403e Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Thu, 10 Feb 2022 15:42:11 -0800 Subject: [PATCH 4/5] economize --- airflow/providers/amazon/aws/sensors/s3.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/airflow/providers/amazon/aws/sensors/s3.py b/airflow/providers/amazon/aws/sensors/s3.py index 1e621ab11c49d..2bda06639f5dc 100644 --- a/airflow/providers/amazon/aws/sensors/s3.py +++ b/airflow/providers/amazon/aws/sensors/s3.py @@ -77,27 +77,21 @@ def __init__( **kwargs, ): super().__init__(**kwargs) - self.bucket_key = bucket_key self.bucket_name = bucket_name + self.bucket_key = bucket_key self.wildcard_match = wildcard_match self.aws_conn_id = aws_conn_id self.verify = verify self.hook: Optional[S3Hook] = None def _resolve_bucket_and_key(self): - """ - If key is URI we should parse the bucket and leave only key portion under - the ``bucket_key`` attr. 
- """ + """If key is URI, parse bucket""" if self.bucket_name is None: self.bucket_name, self.bucket_key = S3Hook.parse_s3_url(self.bucket_key) else: - parsed = urlparse(self.bucket_key) - if parsed.scheme != '' or parsed.netloc != '': - raise AirflowException( - 'If bucket_name is provided, bucket_key should be relative path from root level, ' - 'rather than a full s3:// url' - ) + parsed_url = urlparse(self.bucket_key) + if parsed_url.scheme != '' or parsed_url.netloc != '': + raise AirflowException('If bucket_name provided, bucket_key must be relative path, not URI.') def poke(self, context: 'Context'): self._resolve_bucket_and_key() From 923ed053cb3ea0360d93154bbf9632f71c8bedc1 Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Thu, 10 Feb 2022 15:52:19 -0800 Subject: [PATCH 5/5] fixup! economize --- tests/providers/amazon/aws/sensors/test_s3_key.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/providers/amazon/aws/sensors/test_s3_key.py b/tests/providers/amazon/aws/sensors/test_s3_key.py index a12fed3eae73e..2f6237ebff642 100644 --- a/tests/providers/amazon/aws/sensors/test_s3_key.py +++ b/tests/providers/amazon/aws/sensors/test_s3_key.py @@ -104,36 +104,36 @@ def test_poke(self, mock_check): op = S3KeySensor(task_id='s3_key_sensor', bucket_key='s3://test_bucket/file') mock_check.return_value = False - assert not op.poke(None) + assert op.poke(None) is False mock_check.assert_called_once_with(op.bucket_key, op.bucket_name) mock_check.return_value = True - assert op.poke(None) + assert op.poke(None) is True @mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook.check_for_wildcard_key') def test_poke_wildcard(self, mock_check): op = S3KeySensor(task_id='s3_key_sensor', bucket_key='s3://test_bucket/file', wildcard_match=True) mock_check.return_value = False - assert not op.poke(None) + assert op.poke(None) is False mock_check.assert_called_once_with(op.bucket_key, 
op.bucket_name) mock_check.return_value = True - assert op.poke(None) + assert op.poke(None) is True class TestS3KeySizeSensor(unittest.TestCase): @mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook.check_for_key', return_value=False) def test_poke_check_for_key_false(self, mock_check_for_key): op = S3KeySizeSensor(task_id='s3_key_sensor', bucket_key='s3://test_bucket/file') - assert not op.poke(None) + assert op.poke(None) is False mock_check_for_key.assert_called_once_with(op.bucket_key, op.bucket_name) @mock.patch('airflow.providers.amazon.aws.sensors.s3.S3KeySizeSensor.get_files', return_value=[]) @mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook.check_for_key', return_value=True) def test_poke_get_files_false(self, mock_check_for_key, mock_get_files): op = S3KeySizeSensor(task_id='s3_key_sensor', bucket_key='s3://test_bucket/file') - assert not op.poke(None) + assert op.poke(None) is False mock_check_for_key.assert_called_once_with(op.bucket_key, op.bucket_name) mock_get_files.assert_called_once_with(s3_hook=op.get_hook()) @@ -167,5 +167,5 @@ def test_poke_wildcard(self, mock_check, mock_get_files): op = S3KeySizeSensor(task_id='s3_key_sensor', bucket_key='s3://test_bucket/file', wildcard_match=True) mock_check.return_value = False - assert not op.poke(None) + assert op.poke(None) is False mock_check.assert_called_once_with(op.bucket_key, op.bucket_name)