Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 6 additions & 13 deletions airflow/providers/amazon/aws/sensors/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,31 +77,24 @@ def __init__(
**kwargs,
):
super().__init__(**kwargs)

self.bucket_name = bucket_name
self.bucket_key = bucket_key
self.wildcard_match = wildcard_match
self.aws_conn_id = aws_conn_id
self.verify = verify
self.hook: Optional[S3Hook] = None

def poke(self, context: 'Context'):

def _resolve_bucket_and_key(self):
"""If key is URI, parse bucket"""
if self.bucket_name is None:
parsed_url = urlparse(self.bucket_key)
if parsed_url.netloc == '':
raise AirflowException('If key is a relative path from root, please provide a bucket_name')
self.bucket_name = parsed_url.netloc
self.bucket_key = parsed_url.path.lstrip('/')
self.bucket_name, self.bucket_key = S3Hook.parse_s3_url(self.bucket_key)
else:
parsed_url = urlparse(self.bucket_key)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't have to happen in this PR, but I wonder if S3Hook.parse_s3_url() could be augmented with some sort of configurable check that distinguishes a URI from a relative path? That would centralize some of this logic there instead of having urlparse() calls in different places. Anyway, just a thought.

if parsed_url.scheme != '' or parsed_url.netloc != '':
raise AirflowException(
'If bucket_name is provided, bucket_key'
' should be relative path from root'
' level, rather than a full s3:// url'
)
raise AirflowException('If bucket_name provided, bucket_key must be relative path, not URI.')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
raise AirflowException('If bucket_name provided, bucket_key must be relative path, not URI.')
raise AirflowException('If bucket_name is provided, bucket_key must be a relative path, not a URI.')

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was trying to keep it on one line :) Let me see if that works.


def poke(self, context: 'Context'):
self._resolve_bucket_and_key()
self.log.info('Poking for key : s3://%s/%s', self.bucket_name, self.bucket_key)
if self.wildcard_match:
return self.get_hook().check_for_wildcard_key(self.bucket_key, self.bucket_name)
Expand Down
72 changes: 33 additions & 39 deletions tests/providers/amazon/aws/sensors/test_s3_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@ def test_bucket_name_provided_and_bucket_key_is_s3_url(self):
['key', 'bucket', 'key', 'bucket'],
]
)
@mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook')
def test_parse_bucket_key(self, key, bucket, parsed_key, parsed_bucket, mock_hook):
mock_hook.return_value.check_for_key.return_value = False
@mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook.check_for_key')
def test_parse_bucket_key(self, key, bucket, parsed_key, parsed_bucket, mock_check):
mock_check.return_value = False

op = S3KeySensor(
task_id='s3_key_sensor',
Expand All @@ -73,9 +73,9 @@ def test_parse_bucket_key(self, key, bucket, parsed_key, parsed_bucket, mock_hoo
assert op.bucket_key == parsed_key
assert op.bucket_name == parsed_bucket

@mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook')
def test_parse_bucket_key_from_jinja(self, mock_hook):
mock_hook.return_value.check_for_key.return_value = False
@mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook.check_for_key')
def test_parse_bucket_key_from_jinja(self, mock_check):
mock_check.return_value = False

Variable.set("test_bucket_key", "s3://bucket/key")

Expand All @@ -94,49 +94,46 @@ def test_parse_bucket_key_from_jinja(self, mock_hook):
ti.dag_run = dag_run
context = ti.get_template_context()
ti.render_templates(context)

op.poke(None)

assert op.bucket_key == "key"
assert op.bucket_name == "bucket"

@mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook')
def test_poke(self, mock_hook):
@mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook.check_for_key')
def test_poke(self, mock_check):
op = S3KeySensor(task_id='s3_key_sensor', bucket_key='s3://test_bucket/file')

mock_check_for_key = mock_hook.return_value.check_for_key
mock_check_for_key.return_value = False
assert not op.poke(None)
mock_check_for_key.assert_called_once_with(op.bucket_key, op.bucket_name)
mock_check.return_value = False
assert op.poke(None) is False
mock_check.assert_called_once_with(op.bucket_key, op.bucket_name)

mock_hook.return_value.check_for_key.return_value = True
assert op.poke(None)
mock_check.return_value = True
assert op.poke(None) is True

@mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook')
def test_poke_wildcard(self, mock_hook):
@mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook.check_for_wildcard_key')
def test_poke_wildcard(self, mock_check):
op = S3KeySensor(task_id='s3_key_sensor', bucket_key='s3://test_bucket/file', wildcard_match=True)

mock_check_for_wildcard_key = mock_hook.return_value.check_for_wildcard_key
mock_check_for_wildcard_key.return_value = False
assert not op.poke(None)
mock_check_for_wildcard_key.assert_called_once_with(op.bucket_key, op.bucket_name)
mock_check.return_value = False
assert op.poke(None) is False
mock_check.assert_called_once_with(op.bucket_key, op.bucket_name)

mock_check_for_wildcard_key.return_value = True
assert op.poke(None)
mock_check.return_value = True
assert op.poke(None) is True


class TestS3KeySizeSensor(unittest.TestCase):
@mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook.check_for_key', return_value=False)
def test_poke_check_for_key_false(self, mock_check_for_key):
op = S3KeySizeSensor(task_id='s3_key_sensor', bucket_key='s3://test_bucket/file')
assert not op.poke(None)
assert op.poke(None) is False
mock_check_for_key.assert_called_once_with(op.bucket_key, op.bucket_name)

@mock.patch('airflow.providers.amazon.aws.sensors.s3.S3KeySizeSensor.get_files', return_value=[])
@mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook.check_for_key', return_value=True)
def test_poke_get_files_false(self, mock_check_for_key, mock_get_files):
op = S3KeySizeSensor(task_id='s3_key_sensor', bucket_key='s3://test_bucket/file')
assert not op.poke(None)
assert op.poke(None) is False
mock_check_for_key.assert_called_once_with(op.bucket_key, op.bucket_name)
mock_get_files.assert_called_once_with(s3_hook=op.get_hook())

Expand All @@ -150,28 +147,25 @@ def test_poke_get_files_false(self, mock_check_for_key, mock_get_files):
[{"Contents": [{"Size": 10}, {"Size": 10}]}, True],
]
)
@mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook')
def test_poke(self, paginate_return_value, poke_return_value, mock_hook):
@mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook.get_conn')
@mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook.check_for_key')
def test_poke(self, paginate_return_value, poke_return_value, mock_check, mock_get_conn):
op = S3KeySizeSensor(task_id='s3_key_sensor', bucket_key='s3://test_bucket/file')

mock_check_for_key = mock_hook.return_value.check_for_key
mock_hook.return_value.check_for_key.return_value = True
mock_check.return_value = True
mock_paginator = mock.Mock()
mock_paginator.paginate.return_value = []
mock_conn = mock.Mock()

mock_conn.return_value.get_paginator.return_value = mock_paginator
mock_hook.return_value.get_conn = mock_conn
mock_get_conn.return_value.get_paginator.return_value = mock_paginator
mock_paginator.paginate.return_value = [paginate_return_value]
assert op.poke(None) is poke_return_value
mock_check_for_key.assert_called_once_with(op.bucket_key, op.bucket_name)
mock_check.assert_called_once_with(op.bucket_key, op.bucket_name)

@mock.patch('airflow.providers.amazon.aws.sensors.s3.S3KeySizeSensor.get_files', return_value=[])
@mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook')
def test_poke_wildcard(self, mock_hook, mock_get_files):
@mock.patch('airflow.providers.amazon.aws.sensors.s3.S3Hook.check_for_wildcard_key')
def test_poke_wildcard(self, mock_check, mock_get_files):
op = S3KeySizeSensor(task_id='s3_key_sensor', bucket_key='s3://test_bucket/file', wildcard_match=True)

mock_check_for_wildcard_key = mock_hook.return_value.check_for_wildcard_key
mock_check_for_wildcard_key.return_value = False
assert not op.poke(None)
mock_check_for_wildcard_key.assert_called_once_with(op.bucket_key, op.bucket_name)
mock_check.return_value = False
assert op.poke(None) is False
mock_check.assert_called_once_with(op.bucket_key, op.bucket_name)