diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py index 4bfe177e8cc1d..697d1c50fde97 100644 --- a/airflow/models/dagbag.py +++ b/airflow/models/dagbag.py @@ -39,7 +39,7 @@ from airflow.executors import get_default_executor from airflow.stats import Stats from airflow.utils import timezone -from airflow.utils.dag_processing import list_py_file_paths +from airflow.utils.dag_processing import list_py_file_paths, correct_maybe_zipped from airflow.utils.db import provide_session from airflow.utils.helpers import pprinttable from airflow.utils.log.logging_mixin import LoggingMixin @@ -364,6 +364,8 @@ def collect_dags( FileLoadStat = namedtuple( 'FileLoadStat', "file duration dag_num task_num dags") + dag_folder = correct_maybe_zipped(dag_folder) + for filepath in list_py_file_paths(dag_folder, safe_mode=safe_mode, include_examples=include_examples): try: diff --git a/airflow/utils/dag_processing.py b/airflow/utils/dag_processing.py index f0adae27b4b24..37036ba9410d3 100644 --- a/airflow/utils/dag_processing.py +++ b/airflow/utils/dag_processing.py @@ -277,6 +277,20 @@ def get_dag(self, dag_id): return self.dag_id_to_simple_dag[dag_id] +def correct_maybe_zipped(fileloc): + """ + If the path contains a folder with a .zip suffix, then + the folder is treated as a zip archive and path to zip is returned. + """ + + _, archive, filename = re.search( + r'((.*\.zip){})?(.*)'.format(re.escape(os.sep)), fileloc).groups() + if archive and zipfile.is_zipfile(archive): + return archive + else: + return fileloc + + def list_py_file_paths(directory, safe_mode=True, include_examples=None): """ diff --git a/tests/utils/test_dag_processing.py b/tests/utils/test_dag_processing.py index ad33b7f64372d..611aaac30d33a 100644 --- a/tests/utils/test_dag_processing.py +++ b/tests/utils/test_dag_processing.py @@ -21,6 +21,7 @@ import sys import tempfile import unittest +import mock from datetime import timedelta from mock import MagicMock @@ -32,7 +33,7 @@ from airflow.models import DagBag, TaskInstance as TI from airflow.utils import timezone from airflow.utils.dag_processing import (DagFileProcessorAgent, DagFileProcessorManager, - SimpleTaskInstance) + SimpleTaskInstance, correct_maybe_zipped) from airflow.utils.db import create_session from airflow.utils.state import State @@ -326,3 +327,36 @@ def processor_factory(file_path, zombies): manager_process.join() self.assertTrue(os.path.isfile(log_file_loc)) + + +class TestCorrectMaybeZipped(unittest.TestCase): + @mock.patch("zipfile.is_zipfile") + def test_correct_maybe_zipped_normal_file(self, mocked_is_zipfile): + path = '/path/to/some/file.txt' + mocked_is_zipfile.return_value = False + + dag_folder = correct_maybe_zipped(path) + + self.assertEqual(dag_folder, path) + + @mock.patch("zipfile.is_zipfile") + def test_correct_maybe_zipped_normal_file_with_zip_in_name(self, mocked_is_zipfile): + path = '/path/to/fakearchive.zip.other/file.txt' + mocked_is_zipfile.return_value = False + + dag_folder = correct_maybe_zipped(path) + + self.assertEqual(dag_folder, path) + + @mock.patch("zipfile.is_zipfile") + def test_correct_maybe_zipped_archive(self, mocked_is_zipfile): + path = '/path/to/archive.zip/deep/path/to/file.txt' + mocked_is_zipfile.return_value = True + + dag_folder = correct_maybe_zipped(path) + + assert mocked_is_zipfile.call_count == 1 + (args, kwargs) = mocked_is_zipfile.call_args_list[0] + self.assertEqual('/path/to/archive.zip', args[0]) + + self.assertEqual(dag_folder, '/path/to/archive.zip')