diff --git a/airflow/models/dag.py b/airflow/models/dag.py index e17c960d2e9f6..89c4c93c41e8c 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -1439,6 +1439,9 @@ class DagModel(Base): # Foreign key to the latest pickle_id pickle_id = Column(Integer) # The location of the file containing the DAG object + # Note: Do not depend on fileloc pointing to a file; in the case of a + # packaged DAG, it will point to the subpath of the DAG within the + # associated zip. fileloc = Column(String(2000)) # String representing the owners owners = Column(String(2000)) diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py index 5325588cfd435..477624c5e61e5 100644 --- a/airflow/models/dagbag.py +++ b/airflow/models/dagbag.py @@ -130,7 +130,7 @@ def get_dag(self, dag_id): ): # Reprocess source file found_dags = self.process_file( - filepath=orm_dag.fileloc, only_if_updated=False) + filepath=correct_maybe_zipped(orm_dag.fileloc), only_if_updated=False) # If the source file no longer exports `dag_id`, delete it from self.dags if found_dags and dag_id in [found_dag.dag_id for found_dag in found_dags]: @@ -355,7 +355,6 @@ def collect_dags( """ start_dttm = timezone.utcnow() dag_folder = dag_folder or self.dag_folder - # Used to store stats around DagBag processing stats = [] FileLoadStat = namedtuple( diff --git a/tests/models/test_dagbag.py b/tests/models/test_dagbag.py index eb9ebe33ccd13..f8c945ffbfeae 100644 --- a/tests/models/test_dagbag.py +++ b/tests/models/test_dagbag.py @@ -17,6 +17,7 @@ # specific language governing permissions and limitations # under the License. +from datetime import datetime, timezone import inspect import os import shutil @@ -32,6 +33,7 @@ from airflow.utils.db import create_session from airflow.utils.state import State from tests.models import TEST_DAGS_FOLDER, DEFAULT_DATE +import airflow.example_dags class DagBagTest(unittest.TestCase): @@ -190,20 +192,86 @@ def process_file(self, filepath, only_if_updated=True, safe_mode=True): def test_get_dag_fileloc(self): """ Test that fileloc is correctly set when we load example DAGs, - specifically SubDAGs. + specifically SubDAGs and packaged DAGs. """ dagbag = models.DagBag(include_examples=True) expected = { - 'example_bash_operator': 'example_bash_operator.py', - 'example_subdag_operator': 'example_subdag_operator.py', - 'example_subdag_operator.section-1': 'subdags/subdag.py' + 'example_bash_operator': 'airflow/example_dags/example_bash_operator.py', + 'example_subdag_operator': 'airflow/example_dags/example_subdag_operator.py', + 'example_subdag_operator.section-1': 'airflow/example_dags/subdags/subdag.py', + 'test_zip_dag': 'tests/dags/test_zip.zip/test_zip.py' } for dag_id, path in expected.items(): dag = dagbag.get_dag(dag_id) - self.assertTrue( - dag.fileloc.endswith('airflow/example_dags/' + path)) + self.assertTrue(dag.fileloc.endswith(path)) + + @patch.object(DagModel, "get_current") + def test_refresh_py_dag(self, mock_dagmodel): + """ + Test that we can refresh an ordinary .py DAG + """ + EXAMPLE_DAGS_FOLDER = airflow.example_dags.__path__[0] + + dag_id = "example_bash_operator" + fileloc = os.path.realpath( + os.path.join(EXAMPLE_DAGS_FOLDER, "example_bash_operator.py") + ) + + mock_dagmodel.return_value = DagModel() + mock_dagmodel.return_value.last_expired = datetime.max.replace( + tzinfo=timezone.utc + ) + mock_dagmodel.return_value.fileloc = fileloc + + class TestDagBag(DagBag): + process_file_calls = 0 + + def process_file(self, filepath, only_if_updated=True, safe_mode=True): + if filepath == fileloc: + TestDagBag.process_file_calls += 1 + return super().process_file(filepath, only_if_updated, safe_mode) + + dagbag = TestDagBag(dag_folder=self.empty_dir, include_examples=True) + + self.assertEqual(1, dagbag.process_file_calls) + dag = dagbag.get_dag(dag_id) + self.assertIsNotNone(dag) + self.assertEqual(dag_id, dag.dag_id) + self.assertEqual(2, dagbag.process_file_calls) + + @patch.object(DagModel, "get_current") + def test_refresh_packaged_dag(self, mock_dagmodel): + """ + Test that we can refresh a packaged DAG + """ + dag_id = "test_zip_dag" + fileloc = os.path.realpath( + os.path.join(TEST_DAGS_FOLDER, "test_zip.zip/test_zip.py") + ) + + mock_dagmodel.return_value = DagModel() + mock_dagmodel.return_value.last_expired = datetime.max.replace( + tzinfo=timezone.utc + ) + mock_dagmodel.return_value.fileloc = fileloc + + class TestDagBag(DagBag): + process_file_calls = 0 + + def process_file(self, filepath, only_if_updated=True, safe_mode=True): + if filepath in fileloc: + TestDagBag.process_file_calls += 1 + return super().process_file(filepath, only_if_updated, safe_mode) + + dagbag = TestDagBag(dag_folder=os.path.realpath(TEST_DAGS_FOLDER), include_examples=False) + + self.assertEqual(1, dagbag.process_file_calls) + dag = dagbag.get_dag(dag_id) + self.assertIsNotNone(dag) + self.assertEqual(dag_id, dag.dag_id) + self.assertEqual(2, dagbag.process_file_calls) def process_dag(self, create_dag): """