diff --git a/README.md b/README.md
index 64740ca27a361..0f0d6c79ff8b4 100644
--- a/README.md
+++ b/README.md
@@ -169,6 +169,7 @@ Currently **officially** using Airflow:
 1. [Glassdoor](https://github.com/Glassdoor) [[@syvineckruyk](https://github.com/syvineckruyk)]
 1. [Global Fashion Group](http://global-fashion-group.com) [[@GFG](https://github.com/GFG)]
 1. [GovTech GDS](https://gds-gov.tech) [[@chrissng](https://github.com/chrissng) & [@datagovsg](https://github.com/datagovsg)]
+1. [Grab](https://www.grab.com/sg/) [[@grab](https://github.com/grab)]
 1. [Gradeup](https://gradeup.co) [[@gradeup](https://github.com/gradeup)]
 1. [Grand Rounds](https://www.grandrounds.com/) [[@richddr](https://github.com/richddr), [@timz1290](https://github.com/timz1290), [@wenever](https://github.com/@wenever), & [@runongirlrunon](https://github.com/runongirlrunon)]
 1. [Groupalia](http://es.groupalia.com) [[@jesusfcr](https://github.com/jesusfcr)]
diff --git a/airflow/configuration.py b/airflow/configuration.py
index 9e8584a6855c5..6065a2bc61b96 100644
--- a/airflow/configuration.py
+++ b/airflow/configuration.py
@@ -158,9 +158,9 @@ class AirflowConfigParser(ConfigParser):
 
     def __init__(self, default_config=None, *args, **kwargs):
         super(AirflowConfigParser, self).__init__(*args, **kwargs)
 
-        self.defaults = ConfigParser(*args, **kwargs)
+        self.airflow_defaults = ConfigParser(*args, **kwargs)
         if default_config is not None:
-            self.defaults.read_string(default_config)
+            self.airflow_defaults.read_string(default_config)
 
         self.is_validated = False
@@ -250,9 +250,9 @@ def get(self, section, key, **kwargs):
             return option
 
         # ...then the default config
-        if self.defaults.has_option(section, key):
+        if self.airflow_defaults.has_option(section, key):
             return expand_env_var(
-                self.defaults.get(section, key, **kwargs))
+                self.airflow_defaults.get(section, key, **kwargs))
 
         else:
             log.warning(
@@ -308,8 +308,8 @@ def remove_option(self, section, option, remove_default=True):
         if super(AirflowConfigParser, self).has_option(section, option):
             super(AirflowConfigParser, self).remove_option(section, option)
 
-        if self.defaults.has_option(section, option) and remove_default:
-            self.defaults.remove_option(section, option)
+        if self.airflow_defaults.has_option(section, option) and remove_default:
+            self.airflow_defaults.remove_option(section, option)
 
     def getsection(self, section):
         """
@@ -318,10 +318,11 @@ def getsection(self, section):
         :param section: section from the config
         :return: dict
         """
-        if section not in self._sections and section not in self.defaults._sections:
+        if (section not in self._sections and
+                section not in self.airflow_defaults._sections):
             return None
 
-        _section = copy.deepcopy(self.defaults._sections[section])
+        _section = copy.deepcopy(self.airflow_defaults._sections[section])
         if section in self._sections:
             _section.update(copy.deepcopy(self._sections[section]))
@@ -340,30 +341,35 @@ def getsection(self, section):
                 _section[key] = val
         return _section
 
-    def as_dict(self, display_source=False, display_sensitive=False):
+    def as_dict(
+            self, display_source=False, display_sensitive=False, raw=False):
         """
         Returns the current configuration as an OrderedDict of OrderedDicts.
 
         :param display_source: If False, the option value is returned. If True,
             a tuple of (option_value, source) is returned. Source is either
-            'airflow.cfg' or 'default'.
+            'airflow.cfg', 'default', 'env var', or 'cmd'.
         :type display_source: bool
         :param display_sensitive: If True, the values of options set by env
             vars and bash commands will be displayed. If False, those options
             are shown as '< hidden >'
         :type display_sensitive: bool
+        :param raw: Should the values be output as interpolated values, or the
+            "raw" form that can be fed back into ConfigParser
+        :type raw: bool
         """
-        cfg = copy.deepcopy(self.defaults._sections)
-        cfg.update(copy.deepcopy(self._sections))
-
-        # remove __name__ (affects Python 2 only)
-        for options in cfg.values():
-            options.pop('__name__', None)
-
-        # add source
-        if display_source:
-            for section in cfg:
-                for k, v in cfg[section].items():
-                    cfg[section][k] = (v, 'airflow config')
+        cfg = {}
+        configs = [
+            ('default', self.airflow_defaults),
+            ('airflow.cfg', self),
+        ]
+
+        for (source_name, config) in configs:
+            for section in config.sections():
+                sect = cfg.setdefault(section, OrderedDict())
+                for (k, val) in config.items(section=section, raw=raw):
+                    if display_source:
+                        val = (val, source_name)
+                    sect[k] = val
 
         # add env vars and overwrite because they have priority
         for ev in [ev for ev in os.environ if ev.startswith('AIRFLOW__')]:
@@ -371,16 +377,15 @@ def as_dict(self, display_source=False, display_sensitive=False):
             try:
                 _, section, key = ev.split('__')
                 opt = self._get_env_var_option(section, key)
             except ValueError:
-                opt = None
-            if opt:
-                if (
-                        not display_sensitive and
-                        ev != 'AIRFLOW__CORE__UNIT_TEST_MODE'):
-                    opt = '< hidden >'
-                if display_source:
-                    opt = (opt, 'env var')
-                cfg.setdefault(section.lower(), OrderedDict()).update(
-                    {key.lower(): opt})
+                continue
+            if (not display_sensitive and ev != 'AIRFLOW__CORE__UNIT_TEST_MODE'):
+                opt = '< hidden >'
+            elif raw:
+                opt = opt.replace('%', '%%')
+            if display_source:
+                opt = (opt, 'env var')
+            cfg.setdefault(section.lower(), OrderedDict()).update(
+                {key.lower(): opt})
 
         # add bash commands
         for (section, key) in self.as_command_stdout:
@@ -389,8 +394,11 @@ def as_dict(self, display_source=False, display_sensitive=False):
                 if not display_sensitive:
                     opt = '< hidden >'
                 if display_source:
-                    opt = (opt, 'bash cmd')
+                    opt = (opt, 'cmd')
+                elif raw:
+                    opt = opt.replace('%', '%%')
                 cfg.setdefault(section, OrderedDict()).update({key: opt})
+                del cfg[section][key + '_cmd']
 
         return cfg
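
The new raw flag mirrors ConfigParser's own raw= argument: interpolated reads collapse %% to %, while raw reads keep the escape so the value can be fed back into another parser. A minimal stdlib-only sketch of that round-trip (Python 3, not part of the patch):

    from configparser import ConfigParser

    src = ConfigParser()
    src.read_string('[core]\npercent = with%%inside\n')

    print(dict(src.items('core')))            # {'percent': 'with%inside'}
    print(dict(src.items('core', raw=True)))  # {'percent': 'with%%inside'}

    # Raw values survive a round-trip into a fresh parser, which is the
    # property tmp_configuration_copy() relies on further down.
    dst = ConfigParser()
    dst.read_dict({'core': dict(src.items('core', raw=True))})
    assert dst.get('core', 'percent') == 'with%inside'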
diff --git a/airflow/contrib/auth/backends/google_auth.py b/airflow/contrib/auth/backends/google_auth.py
index bc7d552f59e93..ddbcb1222f2c0 100644
--- a/airflow/contrib/auth/backends/google_auth.py
+++ b/airflow/contrib/auth/backends/google_auth.py
@@ -112,8 +112,7 @@ def login(self, request):
         log.debug('Redirecting user to Google login')
         return self.google_oauth.authorize(callback=url_for(
             'google_oauth_callback',
-            _external=True,
-            _scheme='https'),
+            _external=True),
             state=request.args.get('next') or request.referrer or None)
 
     def get_google_user_profile_info(self, google_token):
diff --git a/airflow/contrib/hooks/aws_hook.py b/airflow/contrib/hooks/aws_hook.py
index 448de63ffe989..8ce74d2b4eedc 100644
--- a/airflow/contrib/hooks/aws_hook.py
+++ b/airflow/contrib/hooks/aws_hook.py
@@ -97,33 +97,36 @@ def _get_credentials(self, region_name):
         if self.aws_conn_id:
             try:
                 connection_object = self.get_connection(self.aws_conn_id)
+                extra_config = connection_object.extra_dejson
                 if connection_object.login:
                     aws_access_key_id = connection_object.login
                     aws_secret_access_key = connection_object.password
 
-                elif 'aws_secret_access_key' in connection_object.extra_dejson:
-                    aws_access_key_id = connection_object.extra_dejson[
+                elif 'aws_secret_access_key' in extra_config:
+                    aws_access_key_id = extra_config[
                         'aws_access_key_id']
-                    aws_secret_access_key = connection_object.extra_dejson[
+                    aws_secret_access_key = extra_config[
                         'aws_secret_access_key']
 
-                elif 's3_config_file' in connection_object.extra_dejson:
+                elif 's3_config_file' in extra_config:
                     aws_access_key_id, aws_secret_access_key = \
                         _parse_s3_config(
-                            connection_object.extra_dejson['s3_config_file'],
-                            connection_object.extra_dejson.get('s3_config_format'))
+                            extra_config['s3_config_file'],
+                            extra_config.get('s3_config_format'),
+                            extra_config.get('profile'))
 
                 if region_name is None:
-                    region_name = connection_object.extra_dejson.get('region_name')
+                    region_name = extra_config.get('region_name')
 
-                role_arn = connection_object.extra_dejson.get('role_arn')
-                external_id = connection_object.extra_dejson.get('external_id')
-                aws_account_id = connection_object.extra_dejson.get('aws_account_id')
-                aws_iam_role = connection_object.extra_dejson.get('aws_iam_role')
+                role_arn = extra_config.get('role_arn')
+                external_id = extra_config.get('external_id')
+                aws_account_id = extra_config.get('aws_account_id')
+                aws_iam_role = extra_config.get('aws_iam_role')
 
                 if role_arn is None and aws_account_id is not None and \
                         aws_iam_role is not None:
-                    role_arn = "arn:aws:iam::" + aws_account_id + ":role/" + aws_iam_role
+                    role_arn = "arn:aws:iam::{}:role/{}" \
+                        .format(aws_account_id, aws_iam_role)
 
                 if role_arn is not None:
                     sts_session = boto3.session.Session(
@@ -143,11 +146,12 @@ def _get_credentials(self, region_name):
                         RoleSessionName='Airflow_' + self.aws_conn_id,
                         ExternalId=external_id)
 
-                    aws_access_key_id = sts_response['Credentials']['AccessKeyId']
-                    aws_secret_access_key = sts_response['Credentials']['SecretAccessKey']
-                    aws_session_token = sts_response['Credentials']['SessionToken']
+                    credentials = sts_response['Credentials']
+                    aws_access_key_id = credentials['AccessKeyId']
+                    aws_secret_access_key = credentials['SecretAccessKey']
+                    aws_session_token = credentials['SessionToken']
 
-                endpoint_url = connection_object.extra_dejson.get('host')
+                endpoint_url = extra_config.get('host')
 
             except AirflowException:
                 # No connection found: fallback on boto3 credential strategy
@@ -183,7 +187,7 @@ def get_credentials(self, region_name=None):
         This contains the attributes: access_key, secret_key and token.
         """
         session, _ = self._get_credentials(region_name)
-        # Credentials are refreshable, so accessing your access key / secret key
-        # separately can lead to a race condition.
+        # Credentials are refreshable, so accessing your access key and
+        # secret key separately can lead to a race condition.
         # See https://stackoverflow.com/a/36291428/8283373
         return session.get_credentials().get_frozen_credentials()
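
With the extras consolidated into extra_config, the assume-role branch is driven entirely by the connection's extra JSON. An illustrative payload (the account id, role name, and external id are made-up placeholders, not values from this patch):

    from airflow.models import Connection

    conn = Connection(
        conn_id='aws_default',
        conn_type='aws',
        extra='{"aws_account_id": "123456789012", '
              '"aws_iam_role": "airflow-worker", '
              '"external_id": "my-external-id", '
              '"region_name": "eu-west-1"}')

    # From this, _get_credentials() would derive:
    #   role_arn == "arn:aws:iam::123456789012:role/airflow-worker"
    print(conn.extra_dejson['aws_account_id'])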
diff --git a/airflow/contrib/hooks/wasb_hook.py b/airflow/contrib/hooks/wasb_hook.py
index 130c19469bd08..a977a98a33740 100644
--- a/airflow/contrib/hooks/wasb_hook.py
+++ b/airflow/contrib/hooks/wasb_hook.py
@@ -18,6 +18,7 @@
 # under the License.
 #
 
+from airflow import AirflowException
 from airflow.hooks.base_hook import BaseHook
 
 from azure.storage.blob import BlockBlobService
@@ -148,3 +149,43 @@ def read_file(self, container_name, blob_name, **kwargs):
         return self.connection.get_blob_to_text(container_name,
                                                 blob_name,
                                                 **kwargs).content
+
+    def delete_file(self, container_name, blob_name, is_prefix=False,
+                    ignore_if_missing=False, **kwargs):
+        """
+        Delete a file from Azure Blob Storage.
+
+        :param container_name: Name of the container.
+        :type container_name: str
+        :param blob_name: Name of the blob.
+        :type blob_name: str
+        :param is_prefix: If blob_name is a prefix, delete all matching files
+        :type is_prefix: bool
+        :param ignore_if_missing: if True, then return success even if the
+            blob does not exist.
+        :type ignore_if_missing: bool
+        :param kwargs: Optional keyword arguments that
+            `BlockBlobService.delete_blob()` takes.
+        :type kwargs: object
+        """
+
+        if is_prefix:
+            blobs_to_delete = [
+                blob.name for blob in self.connection.list_blobs(
+                    container_name, prefix=blob_name, **kwargs
+                )
+            ]
+        elif self.check_for_blob(container_name, blob_name):
+            blobs_to_delete = [blob_name]
+        else:
+            blobs_to_delete = []
+
+        if not ignore_if_missing and len(blobs_to_delete) == 0:
+            raise AirflowException('Blob(s) not found: {}'.format(blob_name))
+
+        for blob_uri in blobs_to_delete:
+            self.log.info("Deleting blob: " + blob_uri)
+            self.connection.delete_blob(container_name,
+                                        blob_uri,
+                                        delete_snapshots='include',
+                                        **kwargs)
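
A short usage sketch for the new hook method (the connection id, container, and prefix are hypothetical):

    from airflow.contrib.hooks.wasb_hook import WasbHook

    hook = WasbHook(wasb_conn_id='wasb_default')
    # Delete every blob under a date-partitioned prefix; an empty match is
    # tolerated because ignore_if_missing=True.
    hook.delete_file('landing-zone', 'raw/2018-09-12/',
                     is_prefix=True, ignore_if_missing=True)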
diff --git a/airflow/contrib/operators/wasb_delete_blob_operator.py b/airflow/contrib/operators/wasb_delete_blob_operator.py
new file mode 100644
index 0000000000000..4634741d8b824
--- /dev/null
+++ b/airflow/contrib/operators/wasb_delete_blob_operator.py
@@ -0,0 +1,71 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+from airflow.contrib.hooks.wasb_hook import WasbHook
+from airflow.models import BaseOperator
+from airflow.utils.decorators import apply_defaults
+
+
+class WasbDeleteBlobOperator(BaseOperator):
+    """
+    Deletes blob(s) on Azure Blob Storage.
+
+    :param container_name: Name of the container. (templated)
+    :type container_name: str
+    :param blob_name: Name of the blob. (templated)
+    :type blob_name: str
+    :param wasb_conn_id: Reference to the wasb connection.
+    :type wasb_conn_id: str
+    :param check_options: Optional keyword arguments that
+        `WasbHook.check_for_blob()` takes.
+    :type check_options: dict
+    :param is_prefix: If blob_name is a prefix, delete all files matching prefix.
+    :type is_prefix: bool
+    :param ignore_if_missing: if True, then return success even if the
+        blob does not exist.
+    :type ignore_if_missing: bool
+    """
+
+    template_fields = ('container_name', 'blob_name')
+
+    @apply_defaults
+    def __init__(self, container_name, blob_name,
+                 wasb_conn_id='wasb_default', check_options=None,
+                 is_prefix=False, ignore_if_missing=False,
+                 *args,
+                 **kwargs):
+        super(WasbDeleteBlobOperator, self).__init__(*args, **kwargs)
+        if check_options is None:
+            check_options = {}
+        self.wasb_conn_id = wasb_conn_id
+        self.container_name = container_name
+        self.blob_name = blob_name
+        self.check_options = check_options
+        self.is_prefix = is_prefix
+        self.ignore_if_missing = ignore_if_missing
+
+    def execute(self, context):
+        self.log.info(
+            'Deleting blob: {self.blob_name}\n'
+            'in wasb://{self.container_name}'.format(**locals())
+        )
+        hook = WasbHook(wasb_conn_id=self.wasb_conn_id)
+
+        hook.delete_file(self.container_name, self.blob_name,
+                         self.is_prefix, self.ignore_if_missing,
+                         **self.check_options)
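
Minimal DAG wiring for the new operator; the dag id, schedule, and blob layout are illustrative only:

    from datetime import datetime

    from airflow import DAG
    from airflow.contrib.operators.wasb_delete_blob_operator import (
        WasbDeleteBlobOperator
    )

    dag = DAG('cleanup_wasb', schedule_interval='@daily',
              start_date=datetime(2018, 9, 1))

    delete_staging = WasbDeleteBlobOperator(
        task_id='delete_staging_blobs',
        container_name='staging',
        blob_name='exports/{{ ds }}/',  # templated via template_fields
        is_prefix=True,
        ignore_if_missing=True,
        dag=dag)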
diff --git a/airflow/migrations/versions/bf00311e1990_add_index_to_taskinstance.py b/airflow/migrations/versions/bf00311e1990_add_index_to_taskinstance.py
new file mode 100644
index 0000000000000..528bd53b366e5
--- /dev/null
+++ b/airflow/migrations/versions/bf00311e1990_add_index_to_taskinstance.py
@@ -0,0 +1,46 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""add index to taskinstance
+
+Revision ID: bf00311e1990
+Revises: dd25f486b8ea
+Create Date: 2018-09-12 09:53:52.007433
+
+"""
+
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = 'bf00311e1990'
+down_revision = 'dd25f486b8ea'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    op.create_index(
+        'ti_dag_date',
+        'task_instance',
+        ['dag_id', 'execution_date'],
+        unique=False
+    )
+
+
+def downgrade():
+    op.drop_index('ti_dag_date', table_name='task_instance')
diff --git a/airflow/models.py b/airflow/models.py
index 624d35eb45a7f..e54c04b284cf6 100755
--- a/airflow/models.py
+++ b/airflow/models.py
@@ -887,6 +887,7 @@ class TaskInstance(Base, LoggingMixin):
 
     __table_args__ = (
         Index('ti_dag_state', dag_id, state),
+        Index('ti_dag_date', dag_id, execution_date),
         Index('ti_state', state),
         Index('ti_state_lkp', dag_id, task_id, execution_date, state),
         Index('ti_pool', pool, state, priority_weight),
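
The new composite index covers lookups by dag id plus execution date. One query shape it can serve, sketched against the ORM (dag id and date are placeholders):

    from airflow import settings
    from airflow.models import TaskInstance
    from airflow.utils import timezone

    session = settings.Session()
    # All task instances of a single dag run, matched on exactly the
    # (dag_id, execution_date) columns that ti_dag_date now indexes.
    tis = (session.query(TaskInstance)
           .filter(TaskInstance.dag_id == 'example_dag',
                   TaskInstance.execution_date ==
                   timezone.datetime(2018, 9, 12))
           .all())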
diff --git a/airflow/task/task_runner/base_task_runner.py b/airflow/task/task_runner/base_task_runner.py
index 0b195047cb170..2a346de939e73 100644
--- a/airflow/task/task_runner/base_task_runner.py
+++ b/airflow/task/task_runner/base_task_runner.py
@@ -60,12 +60,6 @@ def __init__(self, local_task_job):
         # Always provide a copy of the configuration file settings
         cfg_path = tmp_configuration_copy()
 
-        # The following command should always work since the user doing chmod is the same
-        # as the one who just created the file.
-        subprocess.call(
-            ['chmod', '600', cfg_path],
-            close_fds=True
-        )
 
         # Add sudo commands to change user if we need to. Needed to handle SubDagOperator
         # case using a SequentialExecutor.
diff --git a/airflow/utils/configuration.py b/airflow/utils/configuration.py
index 18a338c23f6ff..6a621d5fa9c18 100644
--- a/airflow/utils/configuration.py
+++ b/airflow/utils/configuration.py
@@ -26,16 +26,18 @@
 from airflow import configuration as conf
 
 
-def tmp_configuration_copy():
+def tmp_configuration_copy(chmod=0o600):
     """
     Returns a path for a temporary file including a full copy of the configuration
     settings.
     :return: a path to a temporary file
     """
-    cfg_dict = conf.as_dict(display_sensitive=True)
+    cfg_dict = conf.as_dict(display_sensitive=True, raw=True)
     temp_fd, cfg_path = mkstemp()
 
     with os.fdopen(temp_fd, 'w') as temp_file:
+        if chmod is not None:
+            os.fchmod(temp_fd, chmod)
         json.dump(cfg_dict, temp_file)
 
     return cfg_path
diff --git a/docs/security.rst b/docs/security.rst
index 76a11418a2e25..c14cd1c2c3393 100644
--- a/docs/security.rst
+++ b/docs/security.rst
@@ -12,9 +12,10 @@ Be sure to checkout :doc:`api` for securing the API.
 
 .. note::
 
-   Airflow uses the config parser of Python. This config parser interpolates '%'-signs.
-   Make sure not to have those in your passwords if they do not make sense, otherwise
-   Airflow might leak these passwords on a config parser exception to a log.
+   Airflow uses the config parser of Python. This config parser interpolates
+   '%'-signs. Make sure to escape any ``%`` signs in your config file (but not
+   environment variables) as ``%%``, otherwise Airflow might leak these
+   passwords on a config parser exception to a log.
 
 Web Authentication
 ------------------
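
The two changes above work together: the temporary config copy is now written with a restrictive mode set on the open file descriptor, instead of being chmod-ed by a subprocess after the fact. A sketch of the pattern (Unix-only, since os.fchmod is not available on Windows):

    import json
    import os
    from tempfile import mkstemp

    temp_fd, cfg_path = mkstemp()  # mkstemp already creates the file 0600
    with os.fdopen(temp_fd, 'w') as temp_file:
        os.fchmod(temp_fd, 0o600)  # re-assert the mode, no chmod subprocess
        json.dump({'core': {'unit_test_mode': 'True'}}, temp_file)
    print(cfg_path)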
diff --git a/run_unit_tests.sh b/run_unit_tests.sh
index 8d671583427eb..2c4abbfaad605 100755
--- a/run_unit_tests.sh
+++ b/run_unit_tests.sh
@@ -24,9 +24,6 @@ set -x
 export AIRFLOW_HOME=${AIRFLOW_HOME:=~}
 export AIRFLOW__CORE__UNIT_TEST_MODE=True
 
-# configuration test
-export AIRFLOW__TESTSECTION__TESTKEY=testvalue
-
 # add test/contrib to PYTHONPATH
 DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
 export PYTHONPATH=$PYTHONPATH:${DIR}/tests/test_utils
diff --git a/scripts/ci/5-run-tests.sh b/scripts/ci/5-run-tests.sh
index 8acc78f1a7ff0..8a74d824efa29 100755
--- a/scripts/ci/5-run-tests.sh
+++ b/scripts/ci/5-run-tests.sh
@@ -45,9 +45,6 @@ echo Backend: $AIRFLOW__CORE__SQL_ALCHEMY_CONN
 export AIRFLOW_HOME=${AIRFLOW_HOME:=~}
 export AIRFLOW__CORE__UNIT_TEST_MODE=True
 
-# configuration test
-export AIRFLOW__TESTSECTION__TESTKEY=testvalue
-
 # any argument received is overriding the default nose execution arguments:
 nose_args=$@
diff --git a/setup.py b/setup.py
index 76f55ab01b929..733e22f3b5e46 100644
--- a/setup.py
+++ b/setup.py
@@ -175,7 +175,7 @@ def write_version(filename=os.path.join(*['airflow',
     'sphinx-rtd-theme>=0.1.6',
     'Sphinx-PyPI-upload>=0.2.1'
 ]
-docker = ['docker>=2.0.0']
+docker = ['docker>=2.0.0,<3.0.0']
 druid = ['pydruid>=0.4.1']
 elasticsearch = [
     'elasticsearch>=5.0.0,<6.0.0',
diff --git a/tests/configuration.py b/tests/configuration.py
index e94491ee4e0b7..09284c9972e44 100644
--- a/tests/configuration.py
+++ b/tests/configuration.py
@@ -37,38 +37,64 @@
 
 
 class ConfTest(unittest.TestCase):
 
-    def setup(self):
+    @classmethod
+    def setUpClass(cls):
+        os.environ['AIRFLOW__TESTSECTION__TESTKEY'] = 'testvalue'
+        os.environ['AIRFLOW__TESTSECTION__TESTPERCENT'] = 'with%percent'
         configuration.load_test_config()
+        conf.set('core', 'percent', 'with%%inside')
+
+    @classmethod
+    def tearDownClass(cls):
+        del os.environ['AIRFLOW__TESTSECTION__TESTKEY']
+        del os.environ['AIRFLOW__TESTSECTION__TESTPERCENT']
 
     def test_env_var_config(self):
         opt = conf.get('testsection', 'testkey')
         self.assertEqual(opt, 'testvalue')
 
+        opt = conf.get('testsection', 'testpercent')
+        self.assertEqual(opt, 'with%percent')
+
     def test_conf_as_dict(self):
         cfg_dict = conf.as_dict()
 
         # test that configs are picked up
         self.assertEqual(cfg_dict['core']['unit_test_mode'], 'True')
 
+        self.assertEqual(cfg_dict['core']['percent'], 'with%inside')
+
         # test env vars
         self.assertEqual(cfg_dict['testsection']['testkey'], '< hidden >')
 
+    def test_conf_as_dict_source(self):
         # test display_source
         cfg_dict = conf.as_dict(display_source=True)
         self.assertEqual(
-            cfg_dict['core']['load_examples'][1], 'airflow config')
+            cfg_dict['core']['load_examples'][1], 'airflow.cfg')
         self.assertEqual(
             cfg_dict['testsection']['testkey'], ('< hidden >', 'env var'))
 
+    def test_conf_as_dict_sensitive(self):
         # test display_sensitive
         cfg_dict = conf.as_dict(display_sensitive=True)
         self.assertEqual(cfg_dict['testsection']['testkey'], 'testvalue')
+        self.assertEqual(cfg_dict['testsection']['testpercent'], 'with%percent')
 
         # test display_source and display_sensitive
         cfg_dict = conf.as_dict(display_sensitive=True, display_source=True)
         self.assertEqual(
             cfg_dict['testsection']['testkey'], ('testvalue', 'env var'))
 
+    def test_conf_as_dict_raw(self):
+        # test display_sensitive
+        cfg_dict = conf.as_dict(raw=True, display_sensitive=True)
+        self.assertEqual(cfg_dict['testsection']['testkey'], 'testvalue')
+
+        # Values with '%' in them should be escaped
+        self.assertEqual(cfg_dict['testsection']['testpercent'], 'with%%percent')
+        self.assertEqual(cfg_dict['core']['percent'], 'with%%inside')
+
     def test_command_config(self):
         TEST_CONFIG = '''[test]
 key1 = hello
@@ -104,6 +130,10 @@ def test_command_config(self):
         self.assertFalse(test_conf.has_option('test', 'key5'))
         self.assertTrue(test_conf.has_option('another', 'key6'))
 
+        cfg_dict = test_conf.as_dict(display_sensitive=True)
+        self.assertEqual('cmd_result', cfg_dict['test']['key2'])
+        self.assertNotIn('key2_cmd', cfg_dict['test'])
+
     def test_remove_option(self):
         TEST_CONFIG = '''[test]
 key1 = hello
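
These tests lean on Airflow's AIRFLOW__{SECTION}__{KEY} convention: an environment variable of that shape overrides (or creates) the corresponding config option, which is why the shell scripts above no longer need to export it globally. A quick sketch:

    import os

    os.environ['AIRFLOW__TESTSECTION__TESTKEY'] = 'testvalue'

    from airflow import configuration as conf
    print(conf.get('testsection', 'testkey'))  # 'testvalue'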
diff --git a/tests/contrib/hooks/test_aws_hook.py b/tests/contrib/hooks/test_aws_hook.py
index d7664aca1137a..eaadc5fbff413 100644
--- a/tests/contrib/hooks/test_aws_hook.py
+++ b/tests/contrib/hooks/test_aws_hook.py
@@ -19,6 +19,7 @@
 #
 
 import unittest
+
 import boto3
 
 from airflow import configuration
@@ -146,6 +147,26 @@ def test_get_credentials_from_extra(self, mock_get_connection):
         self.assertEqual(credentials_from_hook.secret_key, 'aws_secret_access_key')
         self.assertIsNone(credentials_from_hook.token)
 
+    @mock.patch('airflow.contrib.hooks.aws_hook._parse_s3_config',
+                return_value=('aws_access_key_id', 'aws_secret_access_key'))
+    @mock.patch.object(AwsHook, 'get_connection')
+    def test_get_credentials_from_extra_with_s3_config_and_profile(
+        self, mock_get_connection, mock_parse_s3_config
+    ):
+        mock_connection = Connection(
+            extra='{"s3_config_format": "aws", '
+                  '"profile": "test", '
+                  '"s3_config_file": "aws-credentials", '
+                  '"region_name": "us-east-1"}')
+        mock_get_connection.return_value = mock_connection
+        hook = AwsHook()
+        hook._get_credentials(region_name=None)
+        mock_parse_s3_config.assert_called_with(
+            'aws-credentials',
+            'aws',
+            'test'
+        )
+
     @unittest.skipIf(mock_sts is None, 'mock_sts package not present')
     @mock.patch.object(AwsHook, 'get_connection')
     @mock_sts
diff --git a/tests/contrib/hooks/test_wasb_hook.py b/tests/contrib/hooks/test_wasb_hook.py
index b5545e2727152..88481440e71be 100644
--- a/tests/contrib/hooks/test_wasb_hook.py
+++ b/tests/contrib/hooks/test_wasb_hook.py
@@ -21,8 +21,9 @@
 
 import json
 import unittest
+from collections import namedtuple
 
-from airflow import configuration
+from airflow import configuration, AirflowException
 from airflow import models
 from airflow.contrib.hooks.wasb_hook import WasbHook
 from airflow.utils import db
@@ -143,6 +144,59 @@ def test_read_file(self, mock_service):
             'container', 'blob', max_connections=1
         )
 
+    @mock.patch('airflow.contrib.hooks.wasb_hook.BlockBlobService',
+                autospec=True)
+    def test_delete_single_blob(self, mock_service):
+        mock_instance = mock_service.return_value
+        hook = WasbHook(wasb_conn_id='wasb_test_sas_token')
+        hook.delete_file('container', 'blob', is_prefix=False)
+        mock_instance.delete_blob.assert_called_once_with(
+            'container', 'blob', delete_snapshots='include'
+        )
+
+    @mock.patch('airflow.contrib.hooks.wasb_hook.BlockBlobService',
+                autospec=True)
+    def test_delete_multiple_blobs(self, mock_service):
+        mock_instance = mock_service.return_value
+        Blob = namedtuple('Blob', ['name'])
+        mock_instance.list_blobs.return_value = iter(
+            [Blob('blob_prefix/blob1'), Blob('blob_prefix/blob2')]
+        )
+        hook = WasbHook(wasb_conn_id='wasb_test_sas_token')
+        hook.delete_file('container', 'blob_prefix', is_prefix=True)
+        mock_instance.delete_blob.assert_any_call(
+            'container', 'blob_prefix/blob1', delete_snapshots='include'
+        )
+        mock_instance.delete_blob.assert_any_call(
+            'container', 'blob_prefix/blob2', delete_snapshots='include'
+        )
+
+    @mock.patch('airflow.contrib.hooks.wasb_hook.BlockBlobService',
+                autospec=True)
+    def test_delete_nonexisting_blob_fails(self, mock_service):
+        mock_instance = mock_service.return_value
+        mock_instance.exists.return_value = False
+        hook = WasbHook(wasb_conn_id='wasb_test_sas_token')
+        with self.assertRaises(Exception) as context:
+            hook.delete_file(
+                'container', 'nonexisting_blob',
+                is_prefix=False, ignore_if_missing=False
+            )
+        self.assertIsInstance(context.exception, AirflowException)
+
+    @mock.patch('airflow.contrib.hooks.wasb_hook.BlockBlobService',
+                autospec=True)
+    def test_delete_multiple_nonexisting_blobs_fails(self, mock_service):
+        mock_instance = mock_service.return_value
+        mock_instance.list_blobs.return_value = iter([])
+        hook = WasbHook(wasb_conn_id='wasb_test_sas_token')
+        with self.assertRaises(Exception) as context:
+            hook.delete_file(
+                'container', 'nonexisting_blob_prefix',
+                is_prefix=True, ignore_if_missing=False
+            )
+        self.assertIsInstance(context.exception, AirflowException)
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/tests/contrib/operators/test_wasb_delete_blob_operator.py b/tests/contrib/operators/test_wasb_delete_blob_operator.py
new file mode 100644
index 0000000000000..7c3ed0bd630ad
--- /dev/null
+++ b/tests/contrib/operators/test_wasb_delete_blob_operator.py
@@ -0,0 +1,91 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import datetime
+import unittest
+
+from airflow import DAG, configuration
+from airflow.contrib.operators.wasb_delete_blob_operator import WasbDeleteBlobOperator
+
+try:
+    from unittest import mock
+except ImportError:
+    try:
+        import mock
+    except ImportError:
+        mock = None
+
+
+class TestWasbDeleteBlobOperator(unittest.TestCase):
+
+    _config = {
+        'container_name': 'container',
+        'blob_name': 'blob',
+    }
+
+    def setUp(self):
+        configuration.load_test_config()
+        args = {
+            'owner': 'airflow',
+            'start_date': datetime.datetime(2017, 1, 1)
+        }
+        self.dag = DAG('test_dag_id', default_args=args)
+
+    def test_init(self):
+        operator = WasbDeleteBlobOperator(
+            task_id='wasb_operator',
+            dag=self.dag,
+            **self._config
+        )
+        self.assertEqual(operator.container_name,
+                         self._config['container_name'])
+        self.assertEqual(operator.blob_name, self._config['blob_name'])
+        self.assertEqual(operator.is_prefix, False)
+        self.assertEqual(operator.ignore_if_missing, False)
+
+        operator = WasbDeleteBlobOperator(
+            task_id='wasb_operator',
+            dag=self.dag,
+            is_prefix=True,
+            ignore_if_missing=True,
+            **self._config
+        )
+        self.assertEqual(operator.is_prefix, True)
+        self.assertEqual(operator.ignore_if_missing, True)
+
+    @mock.patch('airflow.contrib.operators.wasb_delete_blob_operator.WasbHook',
+                autospec=True)
+    def test_execute(self, mock_hook):
+        mock_instance = mock_hook.return_value
+        operator = WasbDeleteBlobOperator(
+            task_id='wasb_operator',
+            dag=self.dag,
+            is_prefix=True,
+            ignore_if_missing=True,
+            **self._config
+        )
+        operator.execute(None)
+        mock_instance.delete_file.assert_called_once_with(
+            'container', 'blob', True, True
+        )
+
+
+if __name__ == '__main__':
+    unittest.main()