Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def do_setup():
zip_safe=False,
scripts=['airflow/bin/airflow'],
install_requires=[
'alembic>=0.8.3, <0.9',
'alembic>=0.9, <1.0',
'bleach~=2.1.3',
'configparser>=3.5.0, <3.6.0',
'croniter>=0.3.17, <0.4',
Expand Down
27 changes: 17 additions & 10 deletions tests/hooks/test_postgres_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@

class TestPostgresHook(unittest.TestCase):

def __init__(self, *args, **kwargs):
super(TestPostgresHook, self).__init__(*args, **kwargs)
self.table = "test_postgres_hook_table"

def setUp(self):
super(TestPostgresHook, self).setUp()

Expand All @@ -43,6 +47,13 @@ def get_conn(self):

self.db_hook = UnitTestPostgresHook()

def tearDown(self):
super(TestPostgresHook, self).tearDown()

with PostgresHook().get_conn() as conn:
with conn.cursor() as cur:
cur.execute("DROP TABLE IF EXISTS {}".format(self.table))

def test_copy_expert(self):
m = mock.mock_open(read_data='{"some": "json"}')
with mock.patch('airflow.hooks.postgres_hook.open', m):
Expand All @@ -61,40 +72,36 @@ def test_copy_expert(self):

def test_bulk_load(self):
hook = PostgresHook()
table = "t"
input_data = ["foo", "bar", "baz"]

with hook.get_conn() as conn:
with conn.cursor() as cur:
cur.execute("DROP TABLE IF EXISTS {}".format(table))
cur.execute("CREATE TABLE {} (c VARCHAR)".format(table))
cur.execute("CREATE TABLE {} (c VARCHAR)".format(self.table))
conn.commit()

with NamedTemporaryFile() as f:
f.write("\n".join(input_data).encode("utf-8"))
f.flush()
hook.bulk_load(table, f.name)
hook.bulk_load(self.table, f.name)

cur.execute("SELECT * FROM {}".format(table))
cur.execute("SELECT * FROM {}".format(self.table))
results = [row[0] for row in cur.fetchall()]

self.assertEqual(sorted(input_data), sorted(results))

def test_bulk_dump(self):
hook = PostgresHook()
table = "t"
input_data = ["foo", "bar", "baz"]

with hook.get_conn() as conn:
with conn.cursor() as cur:
cur.execute("DROP TABLE IF EXISTS {}".format(table))
cur.execute("CREATE TABLE {} (c VARCHAR)".format(table))
cur.execute("CREATE TABLE {} (c VARCHAR)".format(self.table))
values = ",".join("('{}')".format(data) for data in input_data)
cur.execute("INSERT INTO {} VALUES {}".format(table, values))
cur.execute("INSERT INTO {} VALUES {}".format(self.table, values))
conn.commit()

with NamedTemporaryFile() as f:
hook.bulk_dump(table, f.name)
hook.bulk_dump(self.table, f.name)
f.seek(0)
results = [line.rstrip().decode("utf-8") for line in f.readlines()]

Expand Down
7 changes: 7 additions & 0 deletions tests/operators/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,13 @@ def setUp(self):
dag = DAG(TEST_DAG_ID, default_args=args)
self.dag = dag

def tearDown(self):
from airflow.hooks.mysql_hook import MySqlHook
drop_tables = {'test_mysql_to_mysql', 'test_airflow'}
with MySqlHook().get_conn() as conn:
for table in drop_tables:
conn.execute("DROP TABLE IF EXISTS {}".format(table))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's this test for?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previously the tables created by the tests would still exists after the test since there was no cleaning up. This interferes with the test that compares the Alembic models with the actual database.

For Travis we run all the tests exactly once, but we still want to clean up the state (tables in the database) afterwards to exit in a consistent state. Therefore I've added the some cleaning in the tearDown step of the tests.


def test_mysql_operator_test(self):
sql = """
CREATE TABLE IF NOT EXISTS test_airflow (
Expand Down
39 changes: 29 additions & 10 deletions tests/utils/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,20 +45,19 @@ def test_database_schema_and_sqlalchemy_model_are_in_sync(self):
lambda t: (t[0] == 'remove_column' and
t[2] == 'users' and
t[3].name == 'password'),
# ignore tables created by other tests
lambda t: (t[0] == 'remove_table' and
t[1].name == 't'),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice that those can be removed 👍

lambda t: (t[0] == 'remove_table' and
t[1].name == 'test_airflow'),
lambda t: (t[0] == 'remove_table' and
t[1].name == 'test_postgres_to_postgres'),
lambda t: (t[0] == 'remove_table' and
t[1].name == 'test_mysql_to_mysql'),

# ignore tables created by celery
lambda t: (t[0] == 'remove_table' and
t[1].name == 'celery_taskmeta'),
lambda t: (t[0] == 'remove_table' and
t[1].name == 'celery_tasksetmeta'),

# ignore indices created by celery
lambda t: (t[0] == 'remove_index' and
t[1].name == 'task_id'),
lambda t: (t[0] == 'remove_index' and
t[1].name == 'taskset_id'),

# Ignore all the fab tables
lambda t: (t[0] == 'remove_table' and
t[1].name == 'ab_permission'),
Expand All @@ -76,11 +75,31 @@ def test_database_schema_and_sqlalchemy_model_are_in_sync(self):
t[1].name == 'ab_user'),
lambda t: (t[0] == 'remove_table' and
t[1].name == 'ab_view_menu'),

# Ignore all the fab indices
lambda t: (t[0] == 'remove_index' and
t[1].name == 'permission_id'),
lambda t: (t[0] == 'remove_index' and
t[1].name == 'name'),
lambda t: (t[0] == 'remove_index' and
t[1].name == 'user_id'),
lambda t: (t[0] == 'remove_index' and
t[1].name == 'username'),
lambda t: (t[0] == 'remove_index' and
t[1].name == 'field_string'),
lambda t: (t[0] == 'remove_index' and
t[1].name == 'email'),
lambda t: (t[0] == 'remove_index' and
t[1].name == 'permission_view_id'),

# from test_security unit test
lambda t: (t[0] == 'remove_table' and
t[1].name == 'some_model'),
]
for ignore in ignores:
diff = [d for d in diff if not ignore(d)]

self.assertFalse(diff, 'Database schema and SQLAlchemy model are not in sync')
self.assertFalse(
diff,
'Database schema and SQLAlchemy model are not in sync: ' + str(diff)
)