diff --git a/airflow/__init__.py b/airflow/__init__.py index 4c4509e00e15c..c4d6b238dda3e 100644 --- a/airflow/__init__.py +++ b/airflow/__init__.py @@ -75,14 +75,17 @@ class AirflowMacroPlugin(object): def __init__(self, namespace): self.namespace = namespace -from airflow import operators + +from airflow import operators # noqa: E402 from airflow import sensors # noqa: E402 -from airflow import hooks -from airflow import executors -from airflow import macros +from airflow import hooks # noqa: E402 +from airflow import executors # noqa: E402 +from airflow import macros # noqa: E402 +from airflow.dag import fetchers # noqa: E402 operators._integrate_plugins() -sensors._integrate_plugins() # noqa: E402 +sensors._integrate_plugins() hooks._integrate_plugins() executors._integrate_plugins() macros._integrate_plugins() +fetchers._integrate_plugins() diff --git a/airflow/dag/fetchers/__init__.py b/airflow/dag/fetchers/__init__.py new file mode 100644 index 0000000000000..861e63751acf3 --- /dev/null +++ b/airflow/dag/fetchers/__init__.py @@ -0,0 +1,55 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys + +from airflow.utils.log.logging_mixin import LoggingMixin +from airflow.dag.fetchers.filesystem import FileSystemDagFetcher +from airflow.dag.fetchers.hdfs import HDFSDagFetcher +from airflow.dag.fetchers.s3 import S3DagFetcher +from airflow.dag.fetchers.gcs import GCSDagFetcher +from airflow.dag.fetchers.git import GitDagFetcher + + +def get_dag_fetcher(dagbag, dags_uri): + """ + Factory method that returns an instance of the right + DagFetcher, based on the dags_uri prefix. + + Any prefix that does not match keys in the dag_fetchers + dict (or no prefix at all) defaults to FileSystemDagFetcher. + """ + log = LoggingMixin().log + + dag_fetchers = dict( + hdfs=HDFSDagFetcher, + s3=S3DagFetcher, + gcs=GCSDagFetcher, + git=GitDagFetcher) + + uri_schema = dags_uri.split(':')[0] + + if uri_schema not in dag_fetchers: + log.debug('Defaulting to FileSystemDagFetcher') + return FileSystemDagFetcher(dagbag, dags_uri) + + return dag_fetchers[uri_schema](dagbag, dags_uri) + + +def _integrate_plugins(): + """Integrate plugins to the context.""" + from airflow.plugins_manager import dag_fetchers_modules + for dag_fetchers_module in dag_fetchers_modules: + sys.modules[dag_fetchers_module.__name__] = dag_fetchers_module + globals()[dag_fetchers_module._name] = dag_fetchers_module diff --git a/airflow/dag/fetchers/base.py b/airflow/dag/fetchers/base.py new file mode 100644 index 0000000000000..93faae8af35b7 --- /dev/null +++ b/airflow/dag/fetchers/base.py @@ -0,0 +1,64 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from collections import namedtuple
+
+from airflow.utils.log.logging_mixin import LoggingMixin
+
+
+class BaseDagFetcher(LoggingMixin):
+    """
+    Abstract base class for all DagFetchers.
+
+    A DagFetcher's responsibility is to find the dags at
+    the dags_uri and add them to the dagbag.
+
+    Every DagFetcher must implement the fetch method, which
+    returns the list of per-DAG load statistics. It must also
+    implement a process_file method, which is used to reprocess
+    a single DAG file.
+
+    :param dagbag: a DagBag instance, which we will populate
+    :type dagbag: DagBag
+    :param dags_uri: the URI for the dags folder. The scheme
+        prefix determines which subclass will be instantiated
+    :type dags_uri: string
+    :param safe_mode: whether dag files should be processed in safe mode
+    :type safe_mode: boolean
+    """
+    FileLoadStat = namedtuple(
+        'FileLoadStat', 'file duration dag_num task_num dags')
+
+    def __init__(self, dagbag, dags_uri=None, safe_mode=True):
+        self.found_dags = []
+        self.stats = []
+        self.dagbag = dagbag
+        self.dags_uri = dags_uri
+        self.safe_mode = safe_mode
+
+    def process_file(self, filepath, only_if_updated=True):
+        """
+        This method is used to process/reprocess a single file and
+        must be implemented by all DagFetchers.
+
+        Must return the dags found in the file.
+        """
+        raise NotImplementedError()
+
+    def fetch(self, only_if_updated=True):
+        """
+        This is the main method to override when creating a DagFetcher.
+        """
+        raise NotImplementedError()
diff --git a/airflow/dag/fetchers/filesystem.py b/airflow/dag/fetchers/filesystem.py
new file mode 100644
index 0000000000000..57d202db2bd02
--- /dev/null
+++ b/airflow/dag/fetchers/filesystem.py
@@ -0,0 +1,194 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from datetime import datetime
+
+import sys
+import os
+import re
+import zipfile
+import hashlib
+import imp
+import importlib
+
+import airflow
+from airflow import configuration
+from airflow.utils import timezone
+from airflow.utils.timeout import timeout
+from airflow.exceptions import AirflowDagCycleException
+from airflow.dag.fetchers.base import BaseDagFetcher
+
+
+class FileSystemDagFetcher(BaseDagFetcher):
+    """
+    Fetches dags from the local file system, by walking the dags_uri
+    folder on the local disk, looking for .py and .zip files.
+
+    :param dagbag: a DagBag instance, which we will populate
+    :type dagbag: DagBag
+    :param dags_uri: the URI for the dags folder. The scheme
+        prefix determines which subclass will be instantiated
+    :type dags_uri: string
+    :param safe_mode: whether dag files should be processed in safe mode
+    :type safe_mode: boolean
+    """
+    def process_file(self, filepath, only_if_updated=True):
+        """
+        Given a path to a python module or zip file, this method imports
+        the module and looks for dag objects within it.
+        """
+        found_dags = []
+        # if the source file no longer exists in the DB or in the filesystem,
+        # return an empty list
+        # todo: raise exception?
+        if filepath is None or not os.path.isfile(filepath):
+            return found_dags
+
+        try:
+            # This failed before in what may have been a git sync
+            # race condition
+            file_last_changed = datetime.fromtimestamp(
+                os.path.getmtime(filepath))
+            if only_if_updated \
+                    and filepath in self.dagbag.file_last_changed \
+                    and file_last_changed == self.dagbag.file_last_changed[filepath]:
+                return found_dags
+
+        except Exception as e:
+            self.log.exception(e)
+            return found_dags
+
+        mods = []
+        if not zipfile.is_zipfile(filepath):
+            if self.safe_mode and os.path.isfile(filepath):
+                with open(filepath, 'rb') as f:
+                    content = f.read()
+                    if not all([s in content for s in (b'DAG', b'airflow')]):
+                        self.dagbag.file_last_changed[filepath] = file_last_changed
+                        return found_dags
+
+            self.log.debug("Importing %s", filepath)
+            org_mod_name, _ = os.path.splitext(os.path.split(filepath)[-1])
+            mod_name = ('unusual_prefix_' +
+                        hashlib.sha1(filepath.encode('utf-8')).hexdigest() +
+                        '_' + org_mod_name)
+
+            if mod_name in sys.modules:
+                del sys.modules[mod_name]
+
+            with timeout(configuration.getint('core', "DAGBAG_IMPORT_TIMEOUT")):
+                try:
+                    m = imp.load_source(mod_name, filepath)
+                    mods.append(m)
+                except Exception as e:
+                    self.log.exception("Failed to import: %s", filepath)
+                    self.dagbag.import_errors[filepath] = str(e)
+                    self.dagbag.file_last_changed[filepath] = file_last_changed
+
+        else:
+            zip_file = zipfile.ZipFile(filepath)
+            for mod in zip_file.infolist():
+                head, _ = os.path.split(mod.filename)
+                mod_name, ext = os.path.splitext(mod.filename)
+                if not head and (ext == '.py' or ext == '.pyc'):
+                    if mod_name == '__init__':
+                        self.log.warning("Found __init__.%s at root of %s", ext, filepath)
+                    if self.safe_mode:
+                        with zip_file.open(mod.filename) as zf:
+                            self.log.debug("Reading %s from %s", mod.filename, filepath)
+                            content = zf.read()
+                            if not all([s in content for s in (b'DAG', b'airflow')]):
+                                self.dagbag.file_last_changed[filepath] = (
+                                    file_last_changed)
+                                # todo: create ignore list
+                                return found_dags
+
+                    if mod_name in sys.modules:
+                        del sys.modules[mod_name]
+
+                    try:
+                        sys.path.insert(0, filepath)
+                        m = importlib.import_module(mod_name)
+                        mods.append(m)
+                    except Exception as e:
+                        self.log.exception("Failed to import: %s", filepath)
+                        self.dagbag.import_errors[filepath] = str(e)
+                        self.dagbag.file_last_changed[filepath] = file_last_changed
+
+        for m in mods:
+            for dag in list(m.__dict__.values()):
+                if isinstance(dag, airflow.models.DAG):
+                    if not dag.full_filepath:
+                        dag.full_filepath = filepath
+                        if dag.fileloc != filepath:
+                            dag.fileloc = filepath
+                    try:
+                        dag.is_subdag = False
+                        self.dagbag.bag_dag(dag, parent_dag=dag, root_dag=dag)
+                        found_dags.append(dag)
+                        found_dags += dag.subdags
+                    except AirflowDagCycleException as cycle_exception:
+                        self.log.exception("Failed to bag_dag: %s", dag.full_filepath)
+                        self.dagbag.import_errors[dag.full_filepath] = \
+                            str(cycle_exception)
+                        self.dagbag.file_last_changed[dag.full_filepath] = \
+                            file_last_changed
+
+
self.dagbag.file_last_changed[filepath] = file_last_changed + return found_dags + + def fetch(self, only_if_updated=True): + """ + Walks the dags_folder (self.dags_uri) looking for files to process + """ + if os.path.isfile(self.dags_uri): + self.process_file(self.dags_uri, only_if_updated=only_if_updated) + elif os.path.isdir(self.dags_uri): + patterns = [] + for root, dirs, files in os.walk(self.dags_uri, followlinks=True): + ignore_file = [f for f in files if f == '.airflowignore'] + if ignore_file: + f = open(os.path.join(root, ignore_file[0]), 'r') + patterns += [p for p in f.read().split('\n') if p] + f.close() + for f in files: + try: + filepath = os.path.join(root, f) + if not os.path.isfile(filepath): + continue + mod_name, file_ext = os.path.splitext( + os.path.split(filepath)[-1]) + if file_ext != '.py' and not zipfile.is_zipfile(filepath): + continue + if not any( + [re.findall(p, filepath) for p in patterns]): + ts = timezone.utcnow() + found_dags = self.process_file( + filepath, only_if_updated=only_if_updated) + + td = timezone.utcnow() - ts + td = td.total_seconds() + ( + float(td.microseconds) / 1000000) + self.stats.append(self.FileLoadStat( + filepath.replace(self.dags_uri, ''), + td, + len(found_dags), + sum([len(dag.tasks) for dag in found_dags]), + str([dag.dag_id for dag in found_dags]), + )) + except Exception as e: + self.log.exception(e) + + return self.stats diff --git a/airflow/dag/fetchers/gcs.py b/airflow/dag/fetchers/gcs.py new file mode 100644 index 0000000000000..4e22403a72aaf --- /dev/null +++ b/airflow/dag/fetchers/gcs.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from airflow.dag.fetchers.base import BaseDagFetcher + + +class GCSDagFetcher(BaseDagFetcher): + """ + GCSDagFetcher - Not Implemented + """ diff --git a/airflow/dag/fetchers/git.py b/airflow/dag/fetchers/git.py new file mode 100644 index 0000000000000..5b1a7b2269dd7 --- /dev/null +++ b/airflow/dag/fetchers/git.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+
+from airflow.dag.fetchers.base import BaseDagFetcher
+
+
+class GitDagFetcher(BaseDagFetcher):
+    """
+    GitDagFetcher - Not Implemented
+    """
diff --git a/airflow/dag/fetchers/hdfs.py b/airflow/dag/fetchers/hdfs.py
new file mode 100644
index 0000000000000..0495b1d2c9815
--- /dev/null
+++ b/airflow/dag/fetchers/hdfs.py
@@ -0,0 +1,22 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from airflow.dag.fetchers.base import BaseDagFetcher
+
+
+class HDFSDagFetcher(BaseDagFetcher):
+    """
+    HDFSDagFetcher - Not Implemented
+    """
diff --git a/airflow/dag/fetchers/s3.py b/airflow/dag/fetchers/s3.py
new file mode 100644
index 0000000000000..56eacd3d1e0d4
--- /dev/null
+++ b/airflow/dag/fetchers/s3.py
@@ -0,0 +1,22 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# + +from airflow.dag.fetchers.base import BaseDagFetcher + + +class S3DagFetcher(BaseDagFetcher): + """ + S3DagFetcher - Not Implemented + """ diff --git a/airflow/models.py b/airflow/models.py index c1b608afbbef4..33fd5f3aa1334 100755 --- a/airflow/models.py +++ b/airflow/models.py @@ -22,16 +22,13 @@ from builtins import str from builtins import object, bytes import copy -from collections import namedtuple, defaultdict +from collections import defaultdict from datetime import timedelta import dill import functools import getpass -import imp -import importlib import itertools -import zipfile import jinja2 import json import logging @@ -63,6 +60,7 @@ from airflow import settings, utils from airflow.executors import GetDefaultExecutor, LocalExecutor from airflow import configuration +from airflow.dag.fetchers import get_dag_fetcher from airflow.exceptions import ( AirflowDagCycleException, AirflowException, AirflowSkipException, AirflowTaskTimeout ) @@ -206,13 +204,13 @@ def __init__( self.file_last_changed = {} self.executor = executor self.import_errors = {} - + self.dag_fetcher = get_dag_fetcher(self, dag_folder) if include_examples: example_dag_folder = os.path.join( os.path.dirname(__file__), 'example_dags') - self.collect_dags(example_dag_folder) - self.collect_dags(dag_folder) + self.collect_dags(get_dag_fetcher(self, example_dag_folder)) + self.collect_dags(self.dag_fetcher) def size(self): """ @@ -241,8 +239,8 @@ def get_dag(self, dag_id): ) ): # Reprocess source file - found_dags = self.process_file( - filepath=orm_dag.fileloc, only_if_updated=False) + found_dags = self.dag_fetcher.process_file( + orm_dag.fileloc, only_if_updated=False) # If the source file no longer exports `dag_id`, delete it from self.dags if found_dags and dag_id in [dag.dag_id for dag in found_dags]: @@ -251,111 +249,6 @@ def get_dag(self, dag_id): del self.dags[dag_id] return self.dags.get(dag_id) - def process_file(self, filepath, only_if_updated=True, safe_mode=True): - """ - Given a path to a python module or zip file, this method imports - the module and look for dag objects within it. - """ - found_dags = [] - - # if the source file no longer exists in the DB or in the filesystem, - # return an empty list - # todo: raise exception? 
- if filepath is None or not os.path.isfile(filepath): - return found_dags - - try: - # This failed before in what may have been a git sync - # race condition - file_last_changed_on_disk = datetime.fromtimestamp(os.path.getmtime(filepath)) - if only_if_updated \ - and filepath in self.file_last_changed \ - and file_last_changed_on_disk == self.file_last_changed[filepath]: - return found_dags - - except Exception as e: - self.log.exception(e) - return found_dags - - mods = [] - if not zipfile.is_zipfile(filepath): - if safe_mode and os.path.isfile(filepath): - with open(filepath, 'rb') as f: - content = f.read() - if not all([s in content for s in (b'DAG', b'airflow')]): - self.file_last_changed[filepath] = file_last_changed_on_disk - return found_dags - - self.log.debug("Importing %s", filepath) - org_mod_name, _ = os.path.splitext(os.path.split(filepath)[-1]) - mod_name = ('unusual_prefix_' + - hashlib.sha1(filepath.encode('utf-8')).hexdigest() + - '_' + org_mod_name) - - if mod_name in sys.modules: - del sys.modules[mod_name] - - with timeout(configuration.getint('core', "DAGBAG_IMPORT_TIMEOUT")): - try: - m = imp.load_source(mod_name, filepath) - mods.append(m) - except Exception as e: - self.log.exception("Failed to import: %s", filepath) - self.import_errors[filepath] = str(e) - self.file_last_changed[filepath] = file_last_changed_on_disk - - else: - zip_file = zipfile.ZipFile(filepath) - for mod in zip_file.infolist(): - head, _ = os.path.split(mod.filename) - mod_name, ext = os.path.splitext(mod.filename) - if not head and (ext == '.py' or ext == '.pyc'): - if mod_name == '__init__': - self.log.warning("Found __init__.%s at root of %s", ext, filepath) - if safe_mode: - with zip_file.open(mod.filename) as zf: - self.log.debug("Reading %s from %s", mod.filename, filepath) - content = zf.read() - if not all([s in content for s in (b'DAG', b'airflow')]): - self.file_last_changed[filepath] = ( - file_last_changed_on_disk) - # todo: create ignore list - return found_dags - - if mod_name in sys.modules: - del sys.modules[mod_name] - - try: - sys.path.insert(0, filepath) - m = importlib.import_module(mod_name) - mods.append(m) - except Exception as e: - self.log.exception("Failed to import: %s", filepath) - self.import_errors[filepath] = str(e) - self.file_last_changed[filepath] = file_last_changed_on_disk - - for m in mods: - for dag in list(m.__dict__.values()): - if isinstance(dag, DAG): - if not dag.full_filepath: - dag.full_filepath = filepath - if dag.fileloc != filepath: - dag.fileloc = filepath - try: - dag.is_subdag = False - self.bag_dag(dag, parent_dag=dag, root_dag=dag) - found_dags.append(dag) - found_dags += dag.subdags - except AirflowDagCycleException as cycle_exception: - self.log.exception("Failed to bag_dag: %s", dag.full_filepath) - self.import_errors[dag.full_filepath] = str(cycle_exception) - self.file_last_changed[dag.full_filepath] = \ - file_last_changed_on_disk - - - self.file_last_changed[filepath] = file_last_changed_on_disk - return found_dags - @provide_session def kill_zombies(self, session=None): """ @@ -427,10 +320,9 @@ def bag_dag(self, dag, parent_dag, root_dag): del self.dags[subdag.dag_id] raise cycle_exception - def collect_dags( self, - dag_folder=None, + dag_fetcher=None, only_if_updated=True): """ Given a file path or a folder, this method looks for python modules, @@ -442,49 +334,10 @@ def collect_dags( in the file. 
""" start_dttm = timezone.utcnow() - dag_folder = dag_folder or self.dag_folder - - # Used to store stats around DagBag processing - stats = [] - FileLoadStat = namedtuple( - 'FileLoadStat', "file duration dag_num task_num dags") - if os.path.isfile(dag_folder): - self.process_file(dag_folder, only_if_updated=only_if_updated) - elif os.path.isdir(dag_folder): - patterns = [] - for root, dirs, files in os.walk(dag_folder, followlinks=True): - ignore_file = [f for f in files if f == '.airflowignore'] - if ignore_file: - f = open(os.path.join(root, ignore_file[0]), 'r') - patterns += [p for p in f.read().split('\n') if p] - f.close() - for f in files: - try: - filepath = os.path.join(root, f) - if not os.path.isfile(filepath): - continue - mod_name, file_ext = os.path.splitext( - os.path.split(filepath)[-1]) - if file_ext != '.py' and not zipfile.is_zipfile(filepath): - continue - if not any( - [re.findall(p, filepath) for p in patterns]): - ts = timezone.utcnow() - found_dags = self.process_file( - filepath, only_if_updated=only_if_updated) - - td = timezone.utcnow() - ts - td = td.total_seconds() + ( - float(td.microseconds) / 1000000) - stats.append(FileLoadStat( - filepath.replace(dag_folder, ''), - td, - len(found_dags), - sum([len(dag.tasks) for dag in found_dags]), - str([dag.dag_id for dag in found_dags]), - )) - except Exception as e: - self.log.exception(e) + dag_fetcher = dag_fetcher or self.dag_fetcher + + stats = dag_fetcher.fetch(only_if_updated=only_if_updated) + Stats.gauge( 'collect_dags', (timezone.utcnow() - start_dttm).total_seconds(), 1) Stats.gauge( diff --git a/airflow/plugins_manager.py b/airflow/plugins_manager.py index aaae4230b7cb3..939f025a24008 100644 --- a/airflow/plugins_manager.py +++ b/airflow/plugins_manager.py @@ -40,6 +40,7 @@ class AirflowPlugin(object): hooks = [] executors = [] macros = [] + dag_fetchers = [] admin_views = [] flask_blueprints = [] menu_links = [] @@ -108,6 +109,7 @@ def make_module(name, objects): hooks_modules = [] executors_modules = [] macros_modules = [] +dag_fetchers_modules = [] # Plugin components to integrate directly admin_views = [] @@ -124,6 +126,8 @@ def make_module(name, objects): executors_modules.append( make_module('airflow.executors.' + p.name, p.executors)) macros_modules.append(make_module('airflow.macros.' + p.name, p.macros)) + dag_fetchers_modules.append( + make_module('airflow.dag.fetchers.' 
+ p.name, p.dag_fetchers)) admin_views.extend(p.admin_views) flask_blueprints.extend(p.flask_blueprints) diff --git a/tests/models.py b/tests/models.py index 5d8184c575c1c..42259ed238774 100644 --- a/tests/models.py +++ b/tests/models.py @@ -38,6 +38,12 @@ from airflow.models import clear_task_instances from airflow.models import XCom from airflow.models import Connection +from airflow.dag.fetchers import FileSystemDagFetcher +from airflow.dag.fetchers import HDFSDagFetcher +from airflow.dag.fetchers import S3DagFetcher +from airflow.dag.fetchers import GCSDagFetcher +from airflow.dag.fetchers import GitDagFetcher +from airflow.dag.fetchers import get_dag_fetcher from airflow.operators.dummy_operator import DummyOperator from airflow.operators.bash_operator import BashOperator from airflow.operators.python_operator import PythonOperator @@ -920,7 +926,7 @@ def test_get_non_existing_dag(self): non_existing_dag_id = "non_existing_dag_id" self.assertIsNone(dagbag.get_dag(non_existing_dag_id)) - def test_process_file_that_contains_multi_bytes_char(self): + def test_process_local_file_that_contains_multi_bytes_char(self): """ test that we're able to parse file that contains multi-byte char """ @@ -929,42 +935,35 @@ def test_process_file_that_contains_multi_bytes_char(self): f.flush() dagbag = models.DagBag(include_examples=True) - self.assertEqual([], dagbag.process_file(f.name)) + fsfetcher = FileSystemDagFetcher(dagbag) + + self.assertEqual([], fsfetcher.process_file(f.name)) def test_zip(self): """ test the loading of a DAG within a zip file that includes dependencies """ dagbag = models.DagBag() - dagbag.process_file(os.path.join(TEST_DAGS_FOLDER, "test_zip.zip")) + fsfetcher = FileSystemDagFetcher(dagbag) + fsfetcher.process_file(os.path.join(TEST_DAGS_FOLDER, "test_zip.zip")) self.assertTrue(dagbag.get_dag("test_zip_dag")) - @patch.object(DagModel,'get_current') - def test_get_dag_without_refresh(self, mock_dagmodel): + def test_get_dag_fetcher(self): """ - Test that, once a DAG is loaded, it doesn't get refreshed again if it - hasn't been expired. + Test that get_dag_fetcher returns the correct dag fetchers. """ - dag_id = 'example_bash_operator' - - mock_dagmodel.return_value = DagModel() - mock_dagmodel.return_value.last_expired = None - mock_dagmodel.return_value.fileloc = 'foo' - - class TestDagBag(models.DagBag): - process_file_calls = 0 - def process_file(self, filepath, only_if_updated=True, safe_mode=True): - if 'example_bash_operator.py' == os.path.basename(filepath): - TestDagBag.process_file_calls += 1 - super(TestDagBag, self).process_file(filepath, only_if_updated, safe_mode) - - dagbag = TestDagBag(include_examples=True) - processed_files = dagbag.process_file_calls - - # Should not call process_file agani, since it's already loaded during init. 
- self.assertEqual(1, dagbag.process_file_calls) - self.assertIsNotNone(dagbag.get_dag(dag_id)) - self.assertEqual(1, dagbag.process_file_calls) + dagbag = models.DagBag() + default_fetcher = get_dag_fetcher(dagbag, '/a/local/path/without/schema/dags') + hdfs_fetcher = get_dag_fetcher(dagbag, 'hdfs://host:optional-port/dags') + s3_fetcher = get_dag_fetcher(dagbag, 's3://bucket/dags') + gcs_fetcher = get_dag_fetcher(dagbag, 'gcs://bucket/dags') + git_fetcher = get_dag_fetcher(dagbag, 'git://github.com/apache/airflow.git') + + self.assertIsInstance(default_fetcher, FileSystemDagFetcher) + self.assertIsInstance(hdfs_fetcher, HDFSDagFetcher) + self.assertIsInstance(s3_fetcher, S3DagFetcher) + self.assertIsInstance(gcs_fetcher, GCSDagFetcher) + self.assertIsInstance(git_fetcher, GitDagFetcher) def test_get_dag_fileloc(self): """ @@ -996,7 +995,8 @@ def process_dag(self, create_dag): f.flush() dagbag = models.DagBag(include_examples=False) - found_dags = dagbag.process_file(f.name) + fsfetcher = FileSystemDagFetcher(dagbag) + found_dags = fsfetcher.process_file(f.name) return (dagbag, found_dags, f.name) def validate_dags(self, expected_parent_dag, actual_found_dags, actual_dagbag, @@ -1301,13 +1301,14 @@ def subdag_1(): self.validate_dags(testDag, found_dags, dagbag, should_be_found=False) self.assertIn(file_path, dagbag.import_errors) - def test_process_file_with_none(self): + def test_process_local_file_with_none(self): """ test that process_file can handle Nones """ dagbag = models.DagBag(include_examples=True) + fsfetcher = FileSystemDagFetcher(dagbag) - self.assertEqual([], dagbag.process_file(None)) + self.assertEqual([], fsfetcher.process_file(None)) class TaskInstanceTest(unittest.TestCase): diff --git a/tests/plugins/test_plugin.py b/tests/plugins/test_plugin.py index 49325e68b7052..fac674b2140e6 100644 --- a/tests/plugins/test_plugin.py +++ b/tests/plugins/test_plugin.py @@ -24,11 +24,14 @@ from airflow.models import BaseOperator from airflow.sensors.base_sensor_operator import BaseSensorOperator from airflow.executors.base_executor import BaseExecutor +from airflow.dag.fetchers.base import BaseDagFetcher + # Will show up under airflow.hooks.test_plugin.PluginHook class PluginHook(BaseHook): pass + # Will show up under airflow.operators.test_plugin.PluginOperator class PluginOperator(BaseOperator): pass @@ -43,10 +46,17 @@ class PluginSensorOperator(BaseSensorOperator): class PluginExecutor(BaseExecutor): pass + # Will show up under airflow.macros.test_plugin.plugin_macro def plugin_macro(): pass + +# Will show up under airflow.dag.fetchers.test_plugin.PluginDagFetcher +class PluginDagFetcher(BaseDagFetcher): + pass + + # Creating a flask admin BaseView class TestView(BaseView): @expose('/') @@ -76,6 +86,7 @@ class AirflowTestPlugin(AirflowPlugin): hooks = [PluginHook] executors = [PluginExecutor] macros = [plugin_macro] + dag_fetchers = [PluginDagFetcher] admin_views = [v] flask_blueprints = [bp] menu_links = [ml] diff --git a/tests/plugins_manager.py b/tests/plugins_manager.py index a00d476f03852..9b76bf7d3112c 100644 --- a/tests/plugins_manager.py +++ b/tests/plugins_manager.py @@ -26,9 +26,10 @@ from flask_admin.menu import MenuLink, MenuView from airflow.hooks.base_hook import BaseHook -from airflow.models import BaseOperator +from airflow.models import BaseOperator from airflow.sensors.base_sensor_operator import BaseSensorOperator from airflow.executors.base_executor import BaseExecutor +from airflow.dag.fetchers.base import BaseDagFetcher from airflow.www.app import cached_app 
@@ -62,6 +63,10 @@ def test_macros(self): from airflow.macros.test_plugin import plugin_macro self.assertTrue(callable(plugin_macro)) + def test_fetchers(self): + from airflow.dag.fetchers.test_plugin import PluginDagFetcher + self.assertTrue(issubclass(PluginDagFetcher, BaseDagFetcher)) + def test_admin_views(self): app = cached_app() [admin] = app.extensions['admin']
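
Reviewer note: to make the new control flow concrete, here is a short usage sketch of the factory and fetch cycle this patch introduces. Everything it calls (get_dag_fetcher, BaseDagFetcher.fetch, the FileLoadStat namedtuple) is defined in the diff above; the paths and the include_examples flag are placeholders only.

    from airflow import models
    from airflow.dag.fetchers import get_dag_fetcher

    # DagBag.__init__ now builds its own fetcher from dag_folder, so
    # constructing one by hand is only needed for ad-hoc processing,
    # as the updated tests in tests/models.py do.
    dagbag = models.DagBag(include_examples=False)

    # A dags_uri without a recognised scheme ('/a/local/path') falls back to
    # FileSystemDagFetcher; 'hdfs://', 's3://', 'gcs://' and 'git://' select
    # the matching subclass via the dict in airflow/dag/fetchers/__init__.py.
    fetcher = get_dag_fetcher(dagbag, '/a/local/path/dags')

    # fetch() walks the dags_uri, bags every DAG it finds into the dagbag,
    # and returns a list of BaseDagFetcher.FileLoadStat namedtuples
    # (file, duration, dag_num, task_num, dags).
    stats = fetcher.fetch(only_if_updated=True)
    for stat in stats:
        print(stat.file, stat.duration, stat.dag_num)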
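The plugin surface added in plugins_manager.py mirrors the existing operator/hook/executor pattern. Below is a minimal sketch of a third-party fetcher, modeled on tests/plugins/test_plugin.py; the HTTP naming is purely hypothetical. Note that get_dag_fetcher's dispatch dict is hardcoded, so a plugin fetcher becomes importable under airflow.dag.fetchers.<plugin_name> but is not selected automatically by URI scheme.

    from airflow.dag.fetchers.base import BaseDagFetcher
    from airflow.plugins_manager import AirflowPlugin


    # Hypothetical fetcher: the name and the idea of pulling DAG files over
    # HTTP are illustrations, not part of this patch.
    class HTTPDagFetcher(BaseDagFetcher):

        def process_file(self, filepath, only_if_updated=True):
            # Import a single DAG file and bag its DAGs, mirroring
            # FileSystemDagFetcher.process_file; must return the dags found.
            raise NotImplementedError()

        def fetch(self, only_if_updated=True):
            # Enumerate DAG files under self.dags_uri, call process_file on
            # each, and return self.stats (a list of FileLoadStat tuples).
            raise NotImplementedError()


    class HTTPFetcherPlugin(AirflowPlugin):
        name = 'http_fetcher_plugin'
        dag_fetchers = [HTTPDagFetcher]

    # After _integrate_plugins() runs, the class is importable as:
    #   from airflow.dag.fetchers.http_fetcher_plugin import HTTPDagFetcher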