Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -1439,6 +1439,9 @@ class DagModel(Base):
# Foreign key to the latest pickle_id
pickle_id = Column(Integer)
# The location of the file containing the DAG object
# Note: Do not depend on fileloc pointing to a file; in the case of a
# packaged DAG, it will point to the subpath of the DAG within the
# associated zip.
fileloc = Column(String(2000))
# String representing the owners
owners = Column(String(2000))
Expand Down
3 changes: 1 addition & 2 deletions airflow/models/dagbag.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def get_dag(self, dag_id):
):
# Reprocess source file
found_dags = self.process_file(
filepath=orm_dag.fileloc, only_if_updated=False)
filepath=correct_maybe_zipped(orm_dag.fileloc), only_if_updated=False)

# If the source file no longer exports `dag_id`, delete it from self.dags
if found_dags and dag_id in [found_dag.dag_id for found_dag in found_dags]:
Expand Down Expand Up @@ -355,7 +355,6 @@ def collect_dags(
"""
start_dttm = timezone.utcnow()
dag_folder = dag_folder or self.dag_folder

# Used to store stats around DagBag processing
stats = []
FileLoadStat = namedtuple(
Expand Down
80 changes: 74 additions & 6 deletions tests/models/test_dagbag.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# specific language governing permissions and limitations
# under the License.

from datetime import datetime, timezone
import inspect
import os
import shutil
Expand All @@ -32,6 +33,7 @@
from airflow.utils.db import create_session
from airflow.utils.state import State
from tests.models import TEST_DAGS_FOLDER, DEFAULT_DATE
import airflow.example_dags


class DagBagTest(unittest.TestCase):
Expand Down Expand Up @@ -190,20 +192,86 @@ def process_file(self, filepath, only_if_updated=True, safe_mode=True):
def test_get_dag_fileloc(self):
    """
    Test that fileloc is correctly set when we load example DAGs,
    specifically SubDAGs and packaged DAGs.
    """
    dagbag = models.DagBag(include_examples=True)

    # Map each dag_id to the path suffix its fileloc must end with. For the
    # packaged DAG the suffix runs through the zip archive to the module
    # stored inside it.
    expected = {
        'example_bash_operator': 'airflow/example_dags/example_bash_operator.py',
        'example_subdag_operator': 'airflow/example_dags/example_subdag_operator.py',
        'example_subdag_operator.section-1': 'airflow/example_dags/subdags/subdag.py',
        'test_zip_dag': 'tests/dags/test_zip.zip/test_zip.py'
    }

    for dag_id, path in expected.items():
        dag = dagbag.get_dag(dag_id)
        # Only the suffix is compared; the leading part of fileloc is not
        # asserted on.
        self.assertTrue(dag.fileloc.endswith(path))

@patch.object(DagModel, "get_current")
def test_refresh_py_dag(self, mock_dagmodel):
    """
    Test that we can refresh an ordinary .py DAG

    The DagBag subclass below counts how often ``process_file`` is invoked
    for the example DAG's source file: once while the bag is built, and a
    second time when ``get_dag`` is asked for that DAG.
    """
    dag_id = "example_bash_operator"
    examples_root = airflow.example_dags.__path__[0]
    fileloc = os.path.realpath(
        os.path.join(examples_root, "example_bash_operator.py")
    )

    # Stand-in for the ORM row handed back by the patched
    # DagModel.get_current.
    # NOTE(review): last_expired is set to the largest representable
    # datetime, presumably so get_dag treats the cached DAG as expired --
    # confirm against DagBag.get_dag.
    orm_dag = DagModel()
    orm_dag.last_expired = datetime.max.replace(tzinfo=timezone.utc)
    orm_dag.fileloc = fileloc
    mock_dagmodel.return_value = orm_dag

    class TestDagBag(DagBag):
        # Number of process_file calls whose filepath equals fileloc.
        process_file_calls = 0

        def process_file(self, filepath, only_if_updated=True, safe_mode=True):
            is_target = filepath == fileloc
            if is_target:
                TestDagBag.process_file_calls += 1
            return super().process_file(filepath, only_if_updated, safe_mode)

    dagbag = TestDagBag(dag_folder=self.empty_dir, include_examples=True)

    # Building the bag processed the file exactly once.
    self.assertEqual(1, dagbag.process_file_calls)

    dag = dagbag.get_dag(dag_id)
    self.assertIsNotNone(dag)
    self.assertEqual(dag_id, dag.dag_id)

    # Fetching the DAG processed the same file a second time.
    self.assertEqual(2, dagbag.process_file_calls)

@patch.object(DagModel, "get_current")
def test_refresh_packaged_dag(self, mock_dagmodel):
    """
    Test that we can refresh a packaged DAG

    The DagBag subclass below counts ``process_file`` calls whose filepath
    is contained in the packaged DAG's fileloc: one is expected from
    building the bag, and one more from ``get_dag``.
    """
    dag_id = "test_zip_dag"
    packaged_path = os.path.join(TEST_DAGS_FOLDER, "test_zip.zip/test_zip.py")
    fileloc = os.path.realpath(packaged_path)

    # Stand-in for the ORM row handed back by the patched
    # DagModel.get_current.
    # NOTE(review): last_expired is set to the largest representable
    # datetime, presumably so get_dag treats the cached DAG as expired --
    # confirm against DagBag.get_dag.
    orm_dag = DagModel()
    orm_dag.last_expired = datetime.max.replace(tzinfo=timezone.utc)
    orm_dag.fileloc = fileloc
    mock_dagmodel.return_value = orm_dag

    class TestDagBag(DagBag):
        # Number of process_file calls whose filepath is a substring of
        # fileloc.
        process_file_calls = 0

        def process_file(self, filepath, only_if_updated=True, safe_mode=True):
            # Substring match rather than equality: fileloc names the module
            # inside the zip, while process_file is presumably called with
            # the archive path itself -- TODO confirm.
            is_target = filepath in fileloc
            if is_target:
                TestDagBag.process_file_calls += 1
            return super().process_file(filepath, only_if_updated, safe_mode)

    dagbag = TestDagBag(
        dag_folder=os.path.realpath(TEST_DAGS_FOLDER),
        include_examples=False,
    )

    # Building the bag processed the package exactly once.
    self.assertEqual(1, dagbag.process_file_calls)

    dag = dagbag.get_dag(dag_id)
    self.assertIsNotNone(dag)
    self.assertEqual(dag_id, dag.dag_id)

    # Fetching the DAG processed the package a second time.
    self.assertEqual(2, dagbag.process_file_calls)

def process_dag(self, create_dag):
"""
Expand Down