From 667fd6ea96481f255bfc738a7b888f8675c39d29 Mon Sep 17 00:00:00 2001
From: Fokko Driesprong
Date: Thu, 13 Sep 2018 18:06:03 +0200
Subject: [PATCH] [AIRFLOW-3059] Log how many rows are read from Postgres

To know how much data is being read from Postgres, it is useful to log
this to the Airflow log. Previously, when the query returned no data,
the operator would still create (and upload) a single empty file. This
is not the behaviour we want, so it has been changed.

The tests have been refactored to use Postgres itself, since it is
already running in the test environment. This makes the tests more
realistic than mocking everything.
---
 .../operators/postgres_to_gcs_operator.py    |  54 ++++++----
 .../test_postgres_to_gcs_operator.py         | 100 +++++++++++-------
 2 files changed, 94 insertions(+), 60 deletions(-)
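The gist of the operator change: temporary files are now created lazily and
rows are counted as they are written, so an empty result set produces no
local files and therefore no GCS upload. For illustration, below is a
minimal, self-contained sketch of that pattern. It is not the operator's
actual code path: the function and parameter names are illustrative,
Python 3 is assumed, and it relies on the cursor reporting rowcount after
execute, as psycopg2's default client-side cursors (used by PostgresHook) do.

    import json
    import logging
    from tempfile import NamedTemporaryFile


    def write_ndjson_files(cursor, filename_pattern, max_bytes):
        """Write cursor rows as newline-delimited JSON, creating temp files lazily."""
        handles = {}
        row_no = 0

        def new_file():
            # The next file index is simply the number of files created so far.
            handle = NamedTemporaryFile(delete=True)
            handles[filename_pattern.format(len(handles))] = handle
            return handle

        # No rows means no file is ever created, so nothing gets uploaded later.
        if cursor.rowcount > 0:
            handle = new_file()
            columns = [column[0] for column in cursor.description]
            for row in cursor:
                record = dict(zip(columns, row))
                handle.write(json.dumps(record, sort_keys=True).encode('utf-8'))
                handle.write(b'\n')  # newline keeps the dump BigQuery compatible
                # Roll over to a new file once the size limit is reached.
                if handle.tell() >= max_bytes:
                    handle = new_file()
                row_no += 1

        logging.info('Received %s rows over %s files', row_no, len(handles))
        return handles

Keeping file creation in a small helper means the filename index always
matches the number of files created so far, which is the same idea as the
_create_new_file helper in the diff below.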
diff --git a/airflow/contrib/operators/postgres_to_gcs_operator.py b/airflow/contrib/operators/postgres_to_gcs_operator.py
index 850d858f94c58..78da78ee2f20d 100644
--- a/airflow/contrib/operators/postgres_to_gcs_operator.py
+++ b/airflow/contrib/operators/postgres_to_gcs_operator.py
@@ -133,28 +133,38 @@ def _write_local_data_files(self, cursor):
         contain the data for the GCS objects.
         """
         schema = list(map(lambda schema_tuple: schema_tuple[0], cursor.description))
-        file_no = 0
-        tmp_file_handle = NamedTemporaryFile(delete=True)
-        tmp_file_handles = {self.filename.format(file_no): tmp_file_handle}
-
-        for row in cursor:
-            # Convert datetime objects to utc seconds, and decimals to floats
-            row = map(self.convert_types, row)
-            row_dict = dict(zip(schema, row))
-
-            s = json.dumps(row_dict, sort_keys=True)
-            if PY3:
-                s = s.encode('utf-8')
-            tmp_file_handle.write(s)
-
-            # Append newline to make dumps BigQuery compatible.
-            tmp_file_handle.write(b'\n')
-
-            # Stop if the file exceeds the file size limit.
-            if tmp_file_handle.tell() >= self.approx_max_file_size_bytes:
-                file_no += 1
-                tmp_file_handle = NamedTemporaryFile(delete=True)
-                tmp_file_handles[self.filename.format(file_no)] = tmp_file_handle
+        tmp_file_handles = {}
+        row_no = 0
+
+        def _create_new_file():
+            handle = NamedTemporaryFile(delete=True)
+            filename = self.filename.format(len(tmp_file_handles))
+            tmp_file_handles[filename] = handle
+            return handle
+
+        # Don't create a file if there is nothing to write
+        if cursor.rowcount > 0:
+            tmp_file_handle = _create_new_file()
+
+            for row in cursor:
+                # Convert datetime objects to utc seconds, and decimals to floats
+                row = map(self.convert_types, row)
+                row_dict = dict(zip(schema, row))
+
+                s = json.dumps(row_dict, sort_keys=True)
+                if PY3:
+                    s = s.encode('utf-8')
+                tmp_file_handle.write(s)
+
+                # Append newline to make dumps BigQuery compatible.
+                tmp_file_handle.write(b'\n')
+
+                # Stop if the file exceeds the file size limit.
+                if tmp_file_handle.tell() >= self.approx_max_file_size_bytes:
+                    tmp_file_handle = _create_new_file()
+                row_no += 1
+
+        self.log.info('Received %s rows over %s files', row_no, len(tmp_file_handles))
 
         return tmp_file_handles
diff --git a/tests/contrib/operators/test_postgres_to_gcs_operator.py b/tests/contrib/operators/test_postgres_to_gcs_operator.py
index ca72016974c0e..1b6e731c3b304 100644
--- a/tests/contrib/operators/test_postgres_to_gcs_operator.py
+++ b/tests/contrib/operators/test_postgres_to_gcs_operator.py
@@ -7,9 +7,9 @@
 # to you under the Apache License, Version 2.0 (the
 # "License"); you may not use this file except in compliance
 # with the License.  You may obtain a copy of the License at
-#
+#
 #   http://www.apache.org/licenses/LICENSE-2.0
-#
+#
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -25,40 +25,66 @@
 import sys
 import unittest
 
-from airflow.contrib.operators.postgres_to_gcs_operator import PostgresToGoogleCloudStorageOperator
+from airflow.hooks.postgres_hook import PostgresHook
+from airflow.contrib.operators.postgres_to_gcs_operator import \
+    PostgresToGoogleCloudStorageOperator
 
 try:
-    from unittest import mock
+    from unittest.mock import patch
 except ImportError:
     try:
-        import mock
+        from mock import patch
     except ImportError:
         mock = None
 
-PY3 = sys.version_info[0] == 3
+TABLES = {'postgres_to_gcs_operator', 'postgres_to_gcs_operator_empty'}
 
 TASK_ID = 'test-postgres-to-gcs'
-POSTGRES_CONN_ID = 'postgres_conn_test'
-SQL = 'select 1'
+POSTGRES_CONN_ID = 'postgres_default'
+SQL = 'SELECT * FROM postgres_to_gcs_operator'
 BUCKET = 'gs://test'
 FILENAME = 'test_{}.ndjson'
 
-# we expect the psycopg cursor to return encoded strs in py2 and decoded in py3
-if PY3:
-    ROWS = [('mock_row_content_1', 42), ('mock_row_content_2', 43), ('mock_row_content_3', 44)]
-    CURSOR_DESCRIPTION = (('some_str', 0), ('some_num', 1005))
-else:
-    ROWS = [(b'mock_row_content_1', 42), (b'mock_row_content_2', 43), (b'mock_row_content_3', 44)]
-    CURSOR_DESCRIPTION = ((b'some_str', 0), (b'some_num', 1005))
+
 NDJSON_LINES = [
     b'{"some_num": 42, "some_str": "mock_row_content_1"}\n',
     b'{"some_num": 43, "some_str": "mock_row_content_2"}\n',
     b'{"some_num": 44, "some_str": "mock_row_content_3"}\n'
 ]
 SCHEMA_FILENAME = 'schema_test.json'
-SCHEMA_JSON = b'[{"mode": "NULLABLE", "name": "some_str", "type": "STRING"}, {"mode": "REPEATED", "name": "some_num", "type": "INTEGER"}]'
+SCHEMA_JSON = b'[{"mode": "NULLABLE", "name": "some_str", "type": "STRING"}, ' \
+              b'{"mode": "NULLABLE", "name": "some_num", "type": "INTEGER"}]'
 
 
 class PostgresToGoogleCloudStorageOperatorTest(unittest.TestCase):
+    def setUp(self):
+        postgres = PostgresHook()
+        with postgres.get_conn() as conn:
+            with conn.cursor() as cur:
+                for table in TABLES:
+                    cur.execute("DROP TABLE IF EXISTS {} CASCADE;".format(table))
+                    cur.execute("CREATE TABLE {}(some_str varchar, some_num integer);"
+                                .format(table))
+
+                cur.execute(
+                    "INSERT INTO postgres_to_gcs_operator VALUES(%s, %s);",
+                    ('mock_row_content_1', 42)
+                )
+                cur.execute(
+                    "INSERT INTO postgres_to_gcs_operator VALUES(%s, %s);",
+                    ('mock_row_content_2', 43)
+                )
+                cur.execute(
+                    "INSERT INTO postgres_to_gcs_operator VALUES(%s, %s);",
+                    ('mock_row_content_3', 44)
+                )
+
+    def tearDown(self):
+        postgres = PostgresHook()
+        with postgres.get_conn() as conn:
+            with conn.cursor() as cur:
+                for table in TABLES:
+                    cur.execute("DROP TABLE IF EXISTS {} CASCADE;".format(table))
+
     def test_init(self):
         """Test PostgresToGoogleCloudStorageOperator instance is properly initialized."""
         op = PostgresToGoogleCloudStorageOperator(
@@ -68,9 +94,8 @@ def test_init(self):
         self.assertEqual(op.bucket, BUCKET)
         self.assertEqual(op.filename, FILENAME)
 
-    @mock.patch('airflow.contrib.operators.postgres_to_gcs_operator.PostgresHook')
-    @mock.patch('airflow.contrib.operators.postgres_to_gcs_operator.GoogleCloudStorageHook')
-    def test_exec_success(self, gcs_hook_mock_class, pg_hook_mock_class):
+    @patch('airflow.contrib.operators.postgres_to_gcs_operator.GoogleCloudStorageHook')
+    def test_exec_success(self, gcs_hook_mock_class):
         """Test the execute function in case where the run is successful."""
         op = PostgresToGoogleCloudStorageOperator(
             task_id=TASK_ID,
@@ -79,10 +104,6 @@ def test_exec_success(self, gcs_hook_mock_class, pg_hook_mock_class):
             bucket=BUCKET,
             filename=FILENAME)
 
-        pg_hook_mock = pg_hook_mock_class.return_value
-        pg_hook_mock.get_conn().cursor().__iter__.return_value = iter(ROWS)
-        pg_hook_mock.get_conn().cursor().description = CURSOR_DESCRIPTION
-
         gcs_hook_mock = gcs_hook_mock_class.return_value
 
         def _assert_upload(bucket, obj, tmp_filename, content_type):
@@ -96,16 +117,9 @@ def _assert_upload(bucket, obj, tmp_filename, content_type):
 
         op.execute(None)
 
-        pg_hook_mock_class.assert_called_once_with(postgres_conn_id=POSTGRES_CONN_ID)
-        pg_hook_mock.get_conn().cursor().execute.assert_called_once_with(SQL, None)
-
-    @mock.patch('airflow.contrib.operators.postgres_to_gcs_operator.PostgresHook')
-    @mock.patch('airflow.contrib.operators.postgres_to_gcs_operator.GoogleCloudStorageHook')
-    def test_file_splitting(self, gcs_hook_mock_class, pg_hook_mock_class):
+    @patch('airflow.contrib.operators.postgres_to_gcs_operator.GoogleCloudStorageHook')
+    def test_file_splitting(self, gcs_hook_mock_class):
         """Test that ndjson is split by approx_max_file_size_bytes param."""
-        pg_hook_mock = pg_hook_mock_class.return_value
-        pg_hook_mock.get_conn().cursor().__iter__.return_value = iter(ROWS)
-        pg_hook_mock.get_conn().cursor().description = CURSOR_DESCRIPTION
 
         gcs_hook_mock = gcs_hook_mock_class.return_value
         expected_upload = {
@@ -129,13 +143,23 @@ def _assert_upload(bucket, obj, tmp_filename, content_type):
             approx_max_file_size_bytes=len(expected_upload[FILENAME.format(0)]))
         op.execute(None)
 
-    @mock.patch('airflow.contrib.operators.postgres_to_gcs_operator.PostgresHook')
-    @mock.patch('airflow.contrib.operators.postgres_to_gcs_operator.GoogleCloudStorageHook')
-    def test_schema_file(self, gcs_hook_mock_class, pg_hook_mock_class):
+    @patch('airflow.contrib.operators.postgres_to_gcs_operator.GoogleCloudStorageHook')
+    def test_empty_query(self, gcs_hook_mock_class):
+        """If the sql returns no rows, we should not upload any files"""
+        gcs_hook_mock = gcs_hook_mock_class.return_value
+
+        op = PostgresToGoogleCloudStorageOperator(
+            task_id=TASK_ID,
+            sql='SELECT * FROM postgres_to_gcs_operator_empty',
+            bucket=BUCKET,
+            filename=FILENAME)
+        op.execute(None)
+
+        assert not gcs_hook_mock.upload.called, 'No data means no files in the bucket'
+
+    @patch('airflow.contrib.operators.postgres_to_gcs_operator.GoogleCloudStorageHook')
+    def test_schema_file(self, gcs_hook_mock_class):
         """Test writing schema files."""
-        pg_hook_mock = pg_hook_mock_class.return_value
-        pg_hook_mock.get_conn().cursor().__iter__.return_value = iter(ROWS)
-        pg_hook_mock.get_conn().cursor().description = CURSOR_DESCRIPTION
 
         gcs_hook_mock = gcs_hook_mock_class.return_value