diff --git a/airflow/providers/amazon/aws/sensors/s3.py b/airflow/providers/amazon/aws/sensors/s3.py
index 40e31596d78c7..57c9393c0c583 100644
--- a/airflow/providers/amazon/aws/sensors/s3.py
+++ b/airflow/providers/amazon/aws/sensors/s3.py
@@ -86,7 +86,7 @@ def __init__(
     ):
         super().__init__(**kwargs)
         self.bucket_name = bucket_name
-        self.bucket_key = [bucket_key] if isinstance(bucket_key, str) else bucket_key
+        self.bucket_key = bucket_key
         self.wildcard_match = wildcard_match
         self.check_fn = check_fn
         self.aws_conn_id = aws_conn_id
@@ -125,7 +125,10 @@ def _check_key(self, key):
         return True
 
     def poke(self, context: Context):
-        return all(self._check_key(key) for key in self.bucket_key)
+        if isinstance(self.bucket_key, str):
+            return self._check_key(self.bucket_key)
+        else:
+            return all(self._check_key(key) for key in self.bucket_key)
 
     def get_hook(self) -> S3Hook:
         """Create and return an S3Hook"""
diff --git a/tests/providers/amazon/aws/sensors/test_s3_key.py b/tests/providers/amazon/aws/sensors/test_s3_key.py
index 8d560e2c82044..f0832d3df9a89 100644
--- a/tests/providers/amazon/aws/sensors/test_s3_key.py
+++ b/tests/providers/amazon/aws/sensors/test_s3_key.py
@@ -126,6 +126,33 @@ def test_parse_bucket_key_from_jinja(self, mock_head_object):
 
         mock_head_object.assert_called_once_with("key", "bucket")
 
+    @mock.patch("airflow.providers.amazon.aws.sensors.s3.S3Hook.head_object")
+    def test_parse_list_of_bucket_keys_from_jinja(self, mock_head_object):
+        mock_head_object.return_value = None
+        mock_head_object.side_effect = [{"ContentLength": 0}, {"ContentLength": 0}]
+
+        Variable.set("test_bucket_key", ["s3://bucket/file1", "s3://bucket/file2"])
+
+        execution_date = timezone.datetime(2020, 1, 1)
+
+        dag = DAG("test_s3_key", start_date=execution_date, render_template_as_native_obj=True)
+        op = S3KeySensor(
+            task_id="s3_key_sensor",
+            bucket_key="{{ var.value.test_bucket_key }}",
+            bucket_name=None,
+            dag=dag,
+        )
+
+        dag_run = DagRun(dag_id=dag.dag_id, execution_date=execution_date, run_id="test")
+        ti = TaskInstance(task=op)
+        ti.dag_run = dag_run
+        context = ti.get_template_context()
+        ti.render_templates(context)
+        op.poke(None)
+
+        mock_head_object.assert_any_call("file1", "bucket")
+        mock_head_object.assert_any_call("file2", "bucket")
+
     @mock.patch("airflow.providers.amazon.aws.sensors.s3.S3Hook.head_object")
     def test_poke(self, mock_head_object):
         op = S3KeySensor(task_id="s3_key_sensor", bucket_key="s3://test_bucket/file")