diff --git a/tests/hooks/test_mysql_hook.py b/tests/hooks/test_mysql_hook.py index d112f880e9e3a..415d430fc4b77 100644 --- a/tests/hooks/test_mysql_hook.py +++ b/tests/hooks/test_mysql_hook.py @@ -18,12 +18,81 @@ # under the License. # +import json import mock import unittest +import MySQLdb.cursors + +from airflow import models from airflow.hooks.mysql_hook import MySqlHook +class TestMySqlHookConn(unittest.TestCase): + + def setUp(self): + super(TestMySqlHookConn, self).setUp() + + self.connection = models.Connection( + login='login', + password='password', + host='host', + schema='schema', + ) + + self.db_hook = MySqlHook() + self.db_hook.get_connection = mock.Mock() + self.db_hook.get_connection.return_value = self.connection + + @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect') + def test_get_conn(self, mock_connect): + self.db_hook.get_conn() + mock_connect.assert_called_once() + args, kwargs = mock_connect.call_args + self.assertEqual(args, ()) + self.assertEqual(kwargs['user'], 'login') + self.assertEqual(kwargs['passwd'], 'password') + self.assertEqual(kwargs['host'], 'host') + self.assertEqual(kwargs['db'], 'schema') + + @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect') + def test_get_conn_port(self, mock_connect): + self.connection.port = 3307 + self.db_hook.get_conn() + mock_connect.assert_called_once() + args, kwargs = mock_connect.call_args + self.assertEqual(args, ()) + self.assertEqual(kwargs['port'], 3307) + + @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect') + def test_get_conn_charset(self, mock_connect): + self.connection.extra = json.dumps({'charset': 'utf-8'}) + self.db_hook.get_conn() + mock_connect.assert_called_once() + args, kwargs = mock_connect.call_args + self.assertEqual(args, ()) + self.assertEqual(kwargs['charset'], 'utf-8') + self.assertEqual(kwargs['use_unicode'], True) + + @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect') + def test_get_conn_cursor(self, mock_connect): + self.connection.extra = json.dumps({'cursor': 'sscursor'}) + self.db_hook.get_conn() + mock_connect.assert_called_once() + args, kwargs = mock_connect.call_args + self.assertEqual(args, ()) + self.assertEqual(kwargs['cursorclass'], MySQLdb.cursors.SSCursor) + + @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect') + def test_get_conn_local_infile(self, mock_connect): + self.connection.extra = json.dumps({'local_infile': True}) + self.db_hook.get_conn() + mock_connect.assert_called_once() + args, kwargs = mock_connect.call_args + self.assertEqual(args, ()) + self.assertEqual(kwargs['local_infile'], 1) + + class TestMySqlHook(unittest.TestCase): def setUp(self): @@ -85,3 +154,20 @@ def test_run_multi_queries(self): self.assertEqual(kwargs, {}) self.cur.execute.assert_called_with(sql[1]) self.conn.commit.assert_not_called() + + def test_bulk_load(self): + self.db_hook.bulk_load('table', '/tmp/file') + self.cur.execute.assert_called_once_with(""" + LOAD DATA LOCAL INFILE '/tmp/file' + INTO TABLE table + """) + + def test_bulk_dump(self): + self.db_hook.bulk_dump('table', '/tmp/file') + self.cur.execute.assert_called_once_with(""" + SELECT * INTO OUTFILE '/tmp/file' + FROM table + """) + + def test_serialize_cell(self): + self.assertEqual('foo', self.db_hook._serialize_cell('foo', None))