diff --git a/tests/providers/apache/hdfs/sensors/test_webhdfs.py b/tests/providers/apache/hdfs/sensors/test_webhdfs.py
new file mode 100644
index 0000000000000..1988071798d0f
--- /dev/null
+++ b/tests/providers/apache/hdfs/sensors/test_webhdfs.py
@@ -0,0 +1,37 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import os
+import unittest
+
+from airflow.providers.apache.hdfs.sensors.web_hdfs import WebHdfsSensor
+from tests.providers.apache.hive import DEFAULT_DATE, TestHiveEnvironment
+
+
+@unittest.skipIf(
+    'AIRFLOW_RUNALL_TESTS' not in os.environ,
+    "Skipped because AIRFLOW_RUNALL_TESTS is not set")
+class TestWebHdfsSensor(TestHiveEnvironment):
+
+    def test_webhdfs_sensor(self):
+        op = WebHdfsSensor(
+            task_id='webhdfs_sensor_check',
+            filepath='hdfs://user/hive/warehouse/airflow.db/static_babynames',
+            timeout=120,
+            dag=self.dag)
+        op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
+               ignore_ti_state=True)
diff --git a/tests/providers/apache/hive/__init__.py b/tests/providers/apache/hive/__init__.py
index 217e5db960782..dab016ca7a65c 100644
--- a/tests/providers/apache/hive/__init__.py
+++ b/tests/providers/apache/hive/__init__.py
@@ -15,3 +15,34 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
+from datetime import datetime
+from unittest import TestCase
+
+from airflow import DAG
+
+DEFAULT_DATE = datetime(2015, 1, 1)
+DEFAULT_DATE_ISO = DEFAULT_DATE.isoformat()
+DEFAULT_DATE_DS = DEFAULT_DATE_ISO[:10]
+
+
+class TestHiveEnvironment(TestCase):
+
+    def setUp(self):
+        args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
+        dag = DAG('test_dag_id', default_args=args)
+        self.dag = dag
+        self.hql = """
+        USE airflow;
+        DROP TABLE IF EXISTS static_babynames_partitioned;
+        CREATE TABLE IF NOT EXISTS static_babynames_partitioned (
+            state string,
+            year string,
+            name string,
+            gender string,
+            num int)
+        PARTITIONED BY (ds string);
+        INSERT OVERWRITE TABLE static_babynames_partitioned
+        PARTITION(ds='{{ ds }}')
+        SELECT state, year, name, gender, num FROM static_babynames;
+        """
diff --git a/tests/providers/apache/hive/operators/test_hive.py b/tests/providers/apache/hive/operators/test_hive.py
index dc3582730763c..5b64cb59b2278 100644
--- a/tests/providers/apache/hive/operators/test_hive.py
+++ b/tests/providers/apache/hive/operators/test_hive.py
@@ -16,54 +16,15 @@
 # specific language governing permissions and limitations
 # under the License.
 
-import datetime
 import os
 import unittest
 from unittest import mock
 
-from airflow import DAG
 from airflow.configuration import conf
-from airflow.exceptions import AirflowSensorTimeout
 from airflow.models import TaskInstance
-from airflow.providers.apache.hdfs.sensors.hdfs import HdfsSensor
-from airflow.providers.apache.hdfs.sensors.web_hdfs import WebHdfsSensor
 from airflow.providers.apache.hive.operators.hive import HiveOperator
-from airflow.providers.apache.hive.operators.hive_stats import HiveStatsCollectionOperator
-from airflow.providers.apache.hive.operators.hive_to_mysql import HiveToMySqlTransfer
-from airflow.providers.apache.hive.operators.hive_to_samba import Hive2SambaOperator
-from airflow.providers.apache.hive.sensors.hive_partition import HivePartitionSensor
-from airflow.providers.apache.hive.sensors.metastore_partition import MetastorePartitionSensor
-from airflow.providers.apache.hive.sensors.named_hive_partition import NamedHivePartitionSensor
-from airflow.providers.mysql.operators.presto_to_mysql import PrestoToMySqlTransfer
-from airflow.providers.presto.operators.presto_check import PrestoCheckOperator
-from airflow.sensors.sql_sensor import SqlSensor
 from airflow.utils import timezone
-
-DEFAULT_DATE = datetime.datetime(2015, 1, 1)
-DEFAULT_DATE_ISO = DEFAULT_DATE.isoformat()
-DEFAULT_DATE_DS = DEFAULT_DATE_ISO[:10]
-
-
-class TestHiveEnvironment(unittest.TestCase):
-
-    def setUp(self):
-        args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
-        dag = DAG('test_dag_id', default_args=args)
-        self.dag = dag
-        self.hql = """
-        USE airflow;
-        DROP TABLE IF EXISTS static_babynames_partitioned;
-        CREATE TABLE IF NOT EXISTS static_babynames_partitioned (
-            state string,
-            year string,
-            name string,
-            gender string,
-            num int)
-        PARTITIONED BY (ds string);
-        INSERT OVERWRITE TABLE static_babynames_partitioned
-        PARTITION(ds='{{ ds }}')
-        SELECT state, year, name, gender, num FROM static_babynames;
-        """
+from tests.providers.apache.hive import DEFAULT_DATE, TestHiveEnvironment
 
 
 class HiveOperatorConfigTest(TestHiveEnvironment):
@@ -168,150 +129,3 @@ def test_beeline(self):
             hql=self.hql, dag=self.dag)
         op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
                ignore_ti_state=True)
-
-    def test_presto(self):
-        sql = """
-        SELECT count(1) FROM airflow.static_babynames_partitioned;
-        """
-        op = PrestoCheckOperator(
-            task_id='presto_check', sql=sql, dag=self.dag)
-        op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
-               ignore_ti_state=True)
-
-    def test_presto_to_mysql(self):
-        op = PrestoToMySqlTransfer(
-            task_id='presto_to_mysql_check',
-            sql="""
-            SELECT name, count(*) as ccount
-            FROM airflow.static_babynames
-            GROUP BY name
-            """,
-            mysql_table='test_static_babynames',
-            mysql_preoperator='TRUNCATE TABLE test_static_babynames;',
-            dag=self.dag)
-        op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
-               ignore_ti_state=True)
-
-    def test_hdfs_sensor(self):
-        op = HdfsSensor(
-            task_id='hdfs_sensor_check',
-            filepath='hdfs://user/hive/warehouse/airflow.db/static_babynames',
-            dag=self.dag)
-        op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
-               ignore_ti_state=True)
-
-    def test_webhdfs_sensor(self):
-        op = WebHdfsSensor(
-            task_id='webhdfs_sensor_check',
-            filepath='hdfs://user/hive/warehouse/airflow.db/static_babynames',
-            timeout=120,
-            dag=self.dag)
-        op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
-               ignore_ti_state=True)
-
-    def test_sql_sensor(self):
-        op = SqlSensor(
-            task_id='hdfs_sensor_check',
-            conn_id='presto_default',
-            sql="SELECT 'x' FROM airflow.static_babynames LIMIT 1;",
-            dag=self.dag)
-        op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
-               ignore_ti_state=True)
-
-    def test_hive_stats(self):
-        op = HiveStatsCollectionOperator(
-            task_id='hive_stats_check',
-            table="airflow.static_babynames_partitioned",
-            partition={'ds': DEFAULT_DATE_DS},
-            dag=self.dag)
-        op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
-               ignore_ti_state=True)
-
-    def test_named_hive_partition_sensor(self):
-        op = NamedHivePartitionSensor(
-            task_id='hive_partition_check',
-            partition_names=[
-                "airflow.static_babynames_partitioned/ds={{ds}}"
-            ],
-            dag=self.dag)
-        op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
-               ignore_ti_state=True)
-
-    def test_named_hive_partition_sensor_succeeds_on_multiple_partitions(self):
-        op = NamedHivePartitionSensor(
-            task_id='hive_partition_check',
-            partition_names=[
-                "airflow.static_babynames_partitioned/ds={{ds}}",
-                "airflow.static_babynames_partitioned/ds={{ds}}"
-            ],
-            dag=self.dag)
-        op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
-               ignore_ti_state=True)
-
-    def test_named_hive_partition_sensor_parses_partitions_with_periods(self):
-        name = NamedHivePartitionSensor.parse_partition_name(
-            partition="schema.table/part1=this.can.be.an.issue/part2=ok")
-        self.assertEqual(name[0], "schema")
-        self.assertEqual(name[1], "table")
-        self.assertEqual(name[2], "part1=this.can.be.an.issue/part2=this_should_be_ok")
-
-    def test_named_hive_partition_sensor_times_out_on_nonexistent_partition(self):
-        with self.assertRaises(AirflowSensorTimeout):
-            op = NamedHivePartitionSensor(
-                task_id='hive_partition_check',
-                partition_names=[
-                    "airflow.static_babynames_partitioned/ds={{ds}}",
-                    "airflow.static_babynames_partitioned/ds=nonexistent"
-                ],
-                poke_interval=0.1,
-                timeout=1,
-                dag=self.dag)
-            op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
-                   ignore_ti_state=True)
-
-    def test_hive_partition_sensor(self):
-        op = HivePartitionSensor(
-            task_id='hive_partition_check',
-            table='airflow.static_babynames_partitioned',
-            dag=self.dag)
-        op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
-               ignore_ti_state=True)
-
-    def test_hive_metastore_sql_sensor(self):
-        op = MetastorePartitionSensor(
-            task_id='hive_partition_check',
-            table='airflow.static_babynames_partitioned',
-            partition_name='ds={}'.format(DEFAULT_DATE_DS),
-            dag=self.dag)
-        op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
-               ignore_ti_state=True)
-
-    def test_hive2samba(self):
-        op = Hive2SambaOperator(
-            task_id='hive2samba_check',
-            samba_conn_id='tableau_samba',
-            hql="SELECT * FROM airflow.static_babynames LIMIT 10000",
-            destination_filepath='test_airflow.csv',
-            dag=self.dag)
-        op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
-               ignore_ti_state=True)
-
-    def test_hive_to_mysql(self):
-        op = HiveToMySqlTransfer(
-            mysql_conn_id='airflow_db',
-            task_id='hive_to_mysql_check',
-            create=True,
-            sql="""
-            SELECT name
-            FROM airflow.static_babynames
-            LIMIT 100
-            """,
-            mysql_table='test_static_babynames',
-            mysql_preoperator=[
-                'DROP TABLE IF EXISTS test_static_babynames;',
-                'CREATE TABLE test_static_babynames (name VARCHAR(500))',
-            ],
-            dag=self.dag)
-        op.clear(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-        op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
-               ignore_ti_state=True)
diff --git a/tests/providers/apache/hive/operators/test_hive_stats.py b/tests/providers/apache/hive/operators/test_hive_stats.py
index e9f24d7a18b51..05b023a11d0ab 100644
--- a/tests/providers/apache/hive/operators/test_hive_stats.py
+++ b/tests/providers/apache/hive/operators/test_hive_stats.py
@@ -16,12 +16,14 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import os
 import unittest
 from collections import OrderedDict
 from unittest.mock import patch
 
 from airflow import AirflowException
 from airflow.providers.apache.hive.operators.hive_stats import HiveStatsCollectionOperator
+from tests.providers.apache.hive import DEFAULT_DATE, DEFAULT_DATE_DS, TestHiveEnvironment
 
 
 class _FakeCol:
@@ -33,7 +35,7 @@ def __init__(self, col_name, col_type):
 fake_col = _FakeCol('col', 'string')
 
 
-class TestHiveStatsCollectionOperator(unittest.TestCase):
+class TestHiveStatsCollectionOperator(TestHiveEnvironment):
 
     def setUp(self):
         self.kwargs = dict(
@@ -43,8 +45,8 @@ def setUp(self):
             presto_conn_id='presto_conn_id',
             mysql_conn_id='mysql_conn_id',
             task_id='test_hive_stats_collection_operator',
-            dag=None
         )
+        super().setUp()
 
     def test_get_default_exprs(self):
         col = 'col'
@@ -282,3 +284,15 @@ def test_execute_delete_previous_runs_rows(self,
             hive_stats_collection_operator.dttm
         )
         mock_mysql_hook.return_value.run.assert_called_once_with(sql)
+
+    @unittest.skipIf(
+        'AIRFLOW_RUNALL_TESTS' not in os.environ,
+        "Skipped because AIRFLOW_RUNALL_TESTS is not set")
+    def test_runs_for_hive_stats(self):
+        op = HiveStatsCollectionOperator(
+            task_id='hive_stats_check',
+            table="airflow.static_babynames_partitioned",
+            partition={'ds': DEFAULT_DATE_DS},
+            dag=self.dag)
+        op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
+               ignore_ti_state=True)
diff --git a/tests/providers/apache/hive/operators/test_hive_to_mysql.py b/tests/providers/apache/hive/operators/test_hive_to_mysql.py
index bc275ee1dd5d4..437385a1eceb0 100644
--- a/tests/providers/apache/hive/operators/test_hive_to_mysql.py
+++ b/tests/providers/apache/hive/operators/test_hive_to_mysql.py
@@ -15,15 +15,19 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
+import os
 import unittest
 from unittest.mock import PropertyMock, patch
 
 from airflow.providers.apache.hive.operators.hive_to_mysql import HiveToMySqlTransfer
+from airflow.utils import timezone
 from airflow.utils.operator_helpers import context_to_airflow_vars
+from tests.providers.apache.hive import TestHiveEnvironment
+
+DEFAULT_DATE = timezone.datetime(2015, 1, 1)
 
 
-class TestHiveToMySqlTransfer(unittest.TestCase):
+class TestHiveToMySqlTransfer(TestHiveEnvironment):
 
     def setUp(self):
         self.kwargs = dict(
@@ -32,8 +36,8 @@ def setUp(self):
             hiveserver2_conn_id='hiveserver2_default',
             mysql_conn_id='mysql_default',
             task_id='test_hive_to_mysql',
-            dag=None
         )
+        super().setUp()
 
     @patch('airflow.providers.apache.hive.operators.hive_to_mysql.MySqlHook')
     @patch('airflow.providers.apache.hive.operators.hive_to_mysql.HiveServer2Hook')
@@ -105,3 +109,26 @@ def test_execute_with_hive_conf(self, mock_hive_hook, mock_mysql_hook):
             self.kwargs['sql'],
             hive_conf=hive_conf
         )
+
+    @unittest.skipIf(
+        'AIRFLOW_RUNALL_TESTS' not in os.environ,
+        "Skipped because AIRFLOW_RUNALL_TESTS is not set")
+    def test_hive_to_mysql(self):
+        op = HiveToMySqlTransfer(
+            mysql_conn_id='airflow_db',
+            task_id='hive_to_mysql_check',
+            create=True,
+            sql="""
+            SELECT name
+            FROM airflow.static_babynames
+            LIMIT 100
+            """,
+            mysql_table='test_static_babynames',
+            mysql_preoperator=[
+                'DROP TABLE IF EXISTS test_static_babynames;',
+                'CREATE TABLE test_static_babynames (name VARCHAR(500))',
+            ],
+            dag=self.dag)
+        op.clear(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+        op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
+               ignore_ti_state=True)
diff --git a/tests/providers/apache/hive/operators/test_hive_to_samba.py b/tests/providers/apache/hive/operators/test_hive_to_samba.py
index 3addc49309480..58f88149af8c4 100644
--- a/tests/providers/apache/hive/operators/test_hive_to_samba.py
+++ b/tests/providers/apache/hive/operators/test_hive_to_samba.py
@@ -15,15 +15,16 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
+import os
 import unittest
 from unittest.mock import Mock, PropertyMock, patch
 
 from airflow.providers.apache.hive.operators.hive_to_samba import Hive2SambaOperator
 from airflow.utils.operator_helpers import context_to_airflow_vars
+from tests.providers.apache.hive import DEFAULT_DATE, TestHiveEnvironment
 
 
-class TestHive2SambaOperator(unittest.TestCase):
+class TestHive2SambaOperator(TestHiveEnvironment):
 
     def setUp(self):
         self.kwargs = dict(
@@ -32,8 +33,8 @@ def setUp(self):
             samba_conn_id='samba_default',
             hiveserver2_conn_id='hiveserver2_default',
             task_id='test_hive_to_samba_operator',
-            dag=None
         )
+        super().setUp()
 
     @patch('airflow.providers.apache.hive.operators.hive_to_samba.SambaHook')
     @patch('airflow.providers.apache.hive.operators.hive_to_samba.HiveServer2Hook')
@@ -53,3 +54,16 @@ def test_execute(self, mock_tmp_file, mock_hive_hook, mock_samba_hook):
         mock_samba_hook.assert_called_once_with(samba_conn_id=self.kwargs['samba_conn_id'])
         mock_samba_hook.return_value.push_from_local.assert_called_once_with(
             self.kwargs['destination_filepath'], mock_tmp_file.name)
+
+    @unittest.skipIf(
+        'AIRFLOW_RUNALL_TESTS' not in os.environ,
+        "Skipped because AIRFLOW_RUNALL_TESTS is not set")
+    def test_hive2samba(self):
+        op = Hive2SambaOperator(
+            task_id='hive2samba_check',
+            samba_conn_id='tableau_samba',
+            hql="SELECT * FROM airflow.static_babynames LIMIT 10000",
+            destination_filepath='test_airflow.csv',
+            dag=self.dag)
+        op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
+               ignore_ti_state=True)
diff --git a/tests/providers/apache/hive/sensors/test_hdfs.py b/tests/providers/apache/hive/sensors/test_hdfs.py
new file mode 100644
index 0000000000000..b658cdbeceafa
--- /dev/null
+++ b/tests/providers/apache/hive/sensors/test_hdfs.py
@@ -0,0 +1,37 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os
+import unittest
+
+from airflow.providers.apache.hdfs.sensors.hdfs import HdfsSensor
+from tests.providers.apache.hive import DEFAULT_DATE, TestHiveEnvironment
+
+
+@unittest.skipIf(
+    'AIRFLOW_RUNALL_TESTS' not in os.environ,
+    "Skipped because AIRFLOW_RUNALL_TESTS is not set")
+class TestHdfsSensor(TestHiveEnvironment):
+
+    def test_hdfs_sensor(self):
+        op = HdfsSensor(
+            task_id='hdfs_sensor_check',
+            filepath='hdfs://user/hive/warehouse/airflow.db/static_babynames',
+            dag=self.dag)
+        op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
+               ignore_ti_state=True)
diff --git a/tests/providers/apache/hive/sensors/test_hive_partition.py b/tests/providers/apache/hive/sensors/test_hive_partition.py
new file mode 100644
index 0000000000000..1407bef735f04
--- /dev/null
+++ b/tests/providers/apache/hive/sensors/test_hive_partition.py
@@ -0,0 +1,37 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os
+import unittest
+
+from airflow.providers.apache.hive.sensors.hive_partition import HivePartitionSensor
+from tests.providers.apache.hive import DEFAULT_DATE, TestHiveEnvironment
+
+
+@unittest.skipIf(
+    'AIRFLOW_RUNALL_TESTS' not in os.environ,
+    "Skipped because AIRFLOW_RUNALL_TESTS is not set")
+class TestHivePartitionSensor(TestHiveEnvironment):
+
+    def test_hive_partition_sensor(self):
+        op = HivePartitionSensor(
+            task_id='hive_partition_check',
+            table='airflow.static_babynames_partitioned',
+            dag=self.dag)
+        op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
+               ignore_ti_state=True)
diff --git a/tests/providers/apache/hive/sensors/test_metastore_partition.py b/tests/providers/apache/hive/sensors/test_metastore_partition.py
new file mode 100644
index 0000000000000..68fad19fdf8d3
--- /dev/null
+++ b/tests/providers/apache/hive/sensors/test_metastore_partition.py
@@ -0,0 +1,38 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os
+import unittest
+
+from airflow.providers.apache.hive.sensors.metastore_partition import MetastorePartitionSensor
+from tests.providers.apache.hive import DEFAULT_DATE, DEFAULT_DATE_DS, TestHiveEnvironment
+
+
+@unittest.skipIf(
+    'AIRFLOW_RUNALL_TESTS' not in os.environ,
+    "Skipped because AIRFLOW_RUNALL_TESTS is not set")
+class TestMetastorePartitionSensor(TestHiveEnvironment):
+
+    def test_hive_metastore_sql_sensor(self):
+        op = MetastorePartitionSensor(
+            task_id='hive_partition_check',
+            table='airflow.static_babynames_partitioned',
+            partition_name='ds={}'.format(DEFAULT_DATE_DS),
+            dag=self.dag)
+        op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
+               ignore_ti_state=True)
diff --git a/tests/providers/apache/hive/sensors/test_named_hive_partition.py b/tests/providers/apache/hive/sensors/test_named_hive_partition.py
index 2187196b649cd..2ba1f91a96d53 100644
--- a/tests/providers/apache/hive/sensors/test_named_hive_partition.py
+++ b/tests/providers/apache/hive/sensors/test_named_hive_partition.py
@@ -15,14 +15,18 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import os
 import random
 import unittest
 from datetime import timedelta
 
-from airflow import DAG, operators
+from airflow import DAG
+from airflow.exceptions import AirflowSensorTimeout
 from airflow.providers.apache.hive.hooks.hive import HiveMetastoreHook
+from airflow.providers.apache.hive.operators.hive import HiveOperator
 from airflow.providers.apache.hive.sensors.named_hive_partition import NamedHivePartitionSensor
 from airflow.utils.timezone import datetime
+from tests.providers.apache.hive import TestHiveEnvironment
 DEFAULT_DATE = datetime(2015, 1, 1)
 DEFAULT_DATE_ISO = DEFAULT_DATE.isoformat()
 DEFAULT_DATE_DS = DEFAULT_DATE_ISO[:10]
@@ -53,7 +57,7 @@
             ADD PARTITION({{ params.partition_by }}='{{ ds }}');
         """
         self.hook = HiveMetastoreHook()
-        op = operators.hive_operator.HiveOperator(
+        op = HiveOperator(
             task_id='HiveHook_' + str(random.randint(1, 10000)),
             params={
                 'database': self.database,
@@ -124,3 +128,51 @@
             hook=self.hook,
             dag=self.dag)
         self.assertFalse(sensor.poke(None))
+
+
+@unittest.skipIf(
+    'AIRFLOW_RUNALL_TESTS' not in os.environ,
+    "Skipped because AIRFLOW_RUNALL_TESTS is not set")
+class TestPartitions(TestHiveEnvironment):
+
+    def test_succeeds_on_one_partition(self):
+        op = NamedHivePartitionSensor(
+            task_id='hive_partition_check',
+            partition_names=[
+                "airflow.static_babynames_partitioned/ds={{ds}}"
+            ],
+            dag=self.dag)
+        op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
+               ignore_ti_state=True)
+
+    def test_succeeds_on_multiple_partitions(self):
+        op = NamedHivePartitionSensor(
+            task_id='hive_partition_check',
+            partition_names=[
+                "airflow.static_babynames_partitioned/ds={{ds}}",
+                "airflow.static_babynames_partitioned/ds={{ds}}"
+            ],
+            dag=self.dag)
+        op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
+               ignore_ti_state=True)
+
+    def test_parses_partitions_with_periods(self):
+        name = NamedHivePartitionSensor.parse_partition_name(
+            partition="schema.table/part1=this.can.be.an.issue/part2=ok")
+        self.assertEqual(name[0], "schema")
+        self.assertEqual(name[1], "table")
+        self.assertEqual(name[2], "part1=this.can.be.an.issue/part2=ok")
+
+    def test_times_out_on_nonexistent_partition(self):
+        with self.assertRaises(AirflowSensorTimeout):
+            op = NamedHivePartitionSensor(
+                task_id='hive_partition_check',
+                partition_names=[
+                    "airflow.static_babynames_partitioned/ds={{ds}}",
+                    "airflow.static_babynames_partitioned/ds=nonexistent"
+                ],
+                poke_interval=0.1,
+                timeout=1,
+                dag=self.dag)
+            op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
+                   ignore_ti_state=True)
diff --git a/tests/providers/mysql/operators/test_presto_to_mysql.py b/tests/providers/mysql/operators/test_presto_to_mysql.py
index 180dc89d20115..ee877b38d8301 100644
--- a/tests/providers/mysql/operators/test_presto_to_mysql.py
+++ b/tests/providers/mysql/operators/test_presto_to_mysql.py
@@ -15,22 +15,23 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
+import os
 import unittest
 from unittest.mock import patch
 
 from airflow.providers.mysql.operators.presto_to_mysql import PrestoToMySqlTransfer
+from tests.providers.apache.hive import DEFAULT_DATE, TestHiveEnvironment
 
 
-class TestPrestoToMySqlTransfer(unittest.TestCase):
+class TestPrestoToMySqlTransfer(TestHiveEnvironment):
 
     def setUp(self):
         self.kwargs = dict(
             sql='sql',
             mysql_table='mysql_table',
             task_id='test_presto_to_mysql_transfer',
-            dag=None
         )
+        super().setUp()
 
     @patch('airflow.providers.mysql.operators.presto_to_mysql.MySqlHook')
     @patch('airflow.providers.mysql.operators.presto_to_mysql.PrestoHook')
@@ -52,3 +53,20 @@ def test_execute_with_mysql_preoperator(self, mock_presto_hook, mock_mysql_hook)
         mock_mysql_hook.return_value.run.assert_called_once_with(self.kwargs['mysql_preoperator'])
         mock_mysql_hook.return_value.insert_rows.assert_called_once_with(
             table=self.kwargs['mysql_table'], rows=mock_presto_hook.return_value.get_records.return_value)
+
+    @unittest.skipIf(
+        'AIRFLOW_RUNALL_TESTS' not in os.environ,
+        "Skipped because AIRFLOW_RUNALL_TESTS is not set")
+    def test_presto_to_mysql(self):
+        op = PrestoToMySqlTransfer(
+            task_id='presto_to_mysql_check',
+            sql="""
+            SELECT name, count(*) as ccount
+            FROM airflow.static_babynames
+            GROUP BY name
+            """,
+            mysql_table='test_static_babynames',
+            mysql_preoperator='TRUNCATE TABLE test_static_babynames;',
+            dag=self.dag)
+        op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
+               ignore_ti_state=True)
diff --git a/tests/providers/presto/operators/test_presto_check.py b/tests/providers/presto/operators/test_presto_check.py
new file mode 100644
index 0000000000000..e290343922c49
--- /dev/null
+++ b/tests/providers/presto/operators/test_presto_check.py
@@ -0,0 +1,38 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+import os
+import unittest
+
+from airflow.providers.presto.operators.presto_check import PrestoCheckOperator
+from tests.providers.apache.hive import DEFAULT_DATE, TestHiveEnvironment
+
+
+@unittest.skipIf(
+    'AIRFLOW_RUNALL_TESTS' not in os.environ,
+    "Skipped because AIRFLOW_RUNALL_TESTS is not set")
+class TestPrestoCheckOperator(TestHiveEnvironment):
+
+    def test_presto(self):
+        sql = """
+        SELECT count(1) FROM airflow.static_babynames_partitioned;
+        """
+        op = PrestoCheckOperator(
+            task_id='presto_check', sql=sql, dag=self.dag)
+        op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
+               ignore_ti_state=True)
diff --git a/tests/sensors/test_sql_sensor.py b/tests/sensors/test_sql_sensor.py
index cc228e14e3b87..64bd0fca60566 100644
--- a/tests/sensors/test_sql_sensor.py
+++ b/tests/sensors/test_sql_sensor.py
@@ -15,6 +15,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import os
 import unittest
 from unittest import mock
 
@@ -24,14 +25,16 @@
 from airflow.exceptions import AirflowException
 from airflow.sensors.sql_sensor import SqlSensor
 from airflow.utils.timezone import datetime
+from tests.providers.apache.hive import TestHiveEnvironment
 
 DEFAULT_DATE = datetime(2015, 1, 1)
 TEST_DAG_ID = 'unit_test_sql_dag'
 
 
-class TestSqlSensor(unittest.TestCase):
+class TestSqlSensor(TestHiveEnvironment):
 
     def setUp(self):
+        super().setUp()
         args = {
             'owner': 'airflow',
             'start_date': DEFAULT_DATE
@@ -243,3 +246,15 @@ def test_sql_sensor_postgres_poke_invalid_success(self, mock_hook):
 
         mock_get_records.return_value = [[1]]
         self.assertRaises(AirflowException, op.poke, None)
+
+    @unittest.skipIf(
+        'AIRFLOW_RUNALL_TESTS' not in os.environ,
+        "Skipped because AIRFLOW_RUNALL_TESTS is not set")
+    def test_sql_sensor_presto(self):
+        op = SqlSensor(
+            task_id='hdfs_sensor_check',
+            conn_id='presto_default',
+            sql="SELECT 'x' FROM airflow.static_babynames LIMIT 1;",
+            dag=self.dag)
+        op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
+               ignore_ti_state=True)
diff --git a/tests/test_project_structure.py b/tests/test_project_structure.py
index 6c282a2a81bdf..382f00468f027 100644
--- a/tests/test_project_structure.py
+++ b/tests/test_project_structure.py
@@ -31,8 +31,6 @@
     'tests/providers/apache/cassandra/sensors/test_table.py',
     'tests/providers/apache/hdfs/sensors/test_web_hdfs.py',
     'tests/providers/apache/hive/operators/test_vertica_to_hive.py',
-    'tests/providers/apache/hive/sensors/test_hive_partition.py',
-    'tests/providers/apache/hive/sensors/test_metastore_partition.py',
     'tests/providers/apache/pig/operators/test_pig.py',
     'tests/providers/apache/spark/hooks/test_spark_jdbc_script.py',
     'tests/providers/cncf/kubernetes/operators/test_kubernetes_pod.py',
@@ -49,7 +47,6 @@
     'tests/providers/microsoft/mssql/hooks/test_mssql.py',
     'tests/providers/microsoft/mssql/operators/test_mssql.py',
     'tests/providers/oracle/operators/test_oracle.py',
-    'tests/providers/presto/operators/test_presto_check.py',
     'tests/providers/qubole/hooks/test_qubole.py',
     'tests/providers/samba/hooks/test_samba.py',
     'tests/providers/sqlite/operators/test_sqlite.py',