From 87d8a72b7d1fcbee7d0df499f69f6fcca2ea27ec Mon Sep 17 00:00:00 2001
From: Joffrey Bienvenu
Date: Mon, 22 Jan 2024 13:05:25 +0100
Subject: [PATCH 1/4] feat: Enable column renaming on DatabricksSQLHook

---
 .../databricks/hooks/databricks_sql.py       |  4 ++--
 .../databricks/hooks/test_databricks_sql.py  | 20 ++++++++++++++++++++
 2 files changed, 22 insertions(+), 2 deletions(-)

diff --git a/airflow/providers/databricks/hooks/databricks_sql.py b/airflow/providers/databricks/hooks/databricks_sql.py
index 5ffe404e317e5..a7ebcdc180c06 100644
--- a/airflow/providers/databricks/hooks/databricks_sql.py
+++ b/airflow/providers/databricks/hooks/databricks_sql.py
@@ -283,12 +283,12 @@ def _make_common_data_structure(self, result: Sequence[Row] | Row) -> list[tuple
             if not rows:
                 return []
             rows_fields = rows[0].__fields__
-            rows_object = namedtuple("Row", rows_fields)  # type: ignore[misc]
+            rows_object = namedtuple("Row", rows_fields, rename=True)  # type: ignore[misc]
             return cast(List[tuple], [rows_object(*row) for row in rows])
         else:
             row: Row = result
             row_fields = row.__fields__
-            row_object = namedtuple("Row", row_fields)  # type: ignore[misc]
+            row_object = namedtuple("Row", row_fields, rename=True)  # type: ignore[misc]
             return cast(tuple, row_object(*row))

     def bulk_dump(self, table, tmp_file):
diff --git a/tests/providers/databricks/hooks/test_databricks_sql.py b/tests/providers/databricks/hooks/test_databricks_sql.py
index 01392fb0a7dc3..5d56729109a94 100644
--- a/tests/providers/databricks/hooks/test_databricks_sql.py
+++ b/tests/providers/databricks/hooks/test_databricks_sql.py
@@ -287,3 +287,23 @@ def test_no_query(databricks_hook, empty_statement):
     with pytest.raises(ValueError) as err:
         databricks_hook.run(sql=empty_statement)
     assert err.value.args[0] == "List of SQL statements is empty"
+
+
+@pytest.mark.parametrize(
+    "row_objects, fields_names",
+    [
+        pytest.param(Row("count(1)")(9714), ("_0",)),
+        pytest.param(Row("1//@:()")("data"), ("_0",)),
+        pytest.param(Row("class")("data"), ("_0",)),
+        pytest.param(Row("1_wrong", "2_wrong")(1, 2), ("_0", "_1")),
+    ],
+
+)
+def test_incorrect_column_names(row_objects, fields_names):
+    """Ensure that column names can be used as namedtuple attribute.
+
+    namedtuple does not accept special characters or reserved Python keywords
+    as column names. This test ensures that such columns are renamed.
+    """
+    result = DatabricksSqlHook()._make_common_data_structure(row_objects)
+    assert result._fields == fields_names

From b160c4b68a726851a43ebfb683e853d458be6ace Mon Sep 17 00:00:00 2001
From: Joffrey Bienvenu
Date: Mon, 22 Jan 2024 14:30:34 +0100
Subject: [PATCH 2/4] feat: Enable column renaming on ODBCHook

---
 airflow/providers/odbc/hooks/odbc.py          | 17 +++++++++--------
 .../databricks/hooks/test_databricks_sql.py   |  1 -
 2 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/airflow/providers/odbc/hooks/odbc.py b/airflow/providers/odbc/hooks/odbc.py
index 7b0c4d11e9e47..71a838f17c7de 100644
--- a/airflow/providers/odbc/hooks/odbc.py
+++ b/airflow/providers/odbc/hooks/odbc.py
@@ -17,7 +17,8 @@
 """This module contains ODBC hook."""
 from __future__ import annotations

-from typing import Any, List, NamedTuple, Sequence, cast
+from collections import namedtuple
+from typing import Any, List, Sequence, cast
 from urllib.parse import quote_plus

 from pyodbc import Connection, Row, connect
@@ -229,16 +230,16 @@ def get_sqlalchemy_connection(
         return cnx

     def _make_common_data_structure(self, result: Sequence[Row] | Row) -> list[tuple] | tuple:
-        """Transform the pyodbc.Row objects returned from an SQL command into typed NamedTuples."""
-        # Below ignored lines respect NamedTuple docstring, but mypy do not support dynamically
-        # instantiated typed Namedtuple, and will never do: https://github.com/python/mypy/issues/848
+        """Transform the pyodbc.Row objects returned from an SQL command into namedtuples."""
+        # The ignored lines below respect the namedtuple docstring, but mypy does not support
+        # dynamically instantiated namedtuples, and never will: https://github.com/python/mypy/issues/848
         field_names: list[tuple[str, type]] | None = None
         if not result:
             return []
         if isinstance(result, Sequence):
-            field_names = [col[:2] for col in result[0].cursor_description]
-            row_object = NamedTuple("Row", field_names)  # type: ignore[misc]
+            field_names = [col[0] for col in result[0].cursor_description]
+            row_object = namedtuple("Row", field_names, rename=True)  # type: ignore[misc]
             return cast(List[tuple], [row_object(*row) for row in result])
         else:
-            field_names = [col[:2] for col in result.cursor_description]
-            return cast(tuple, NamedTuple("Row", field_names)(*result))  # type: ignore[misc, operator]
+            field_names = [col[0] for col in result.cursor_description]
+            return cast(tuple, namedtuple("Row", field_names, rename=True)(*result))  # type: ignore[misc, operator]
diff --git a/tests/providers/databricks/hooks/test_databricks_sql.py b/tests/providers/databricks/hooks/test_databricks_sql.py
index 5d56729109a94..6fb4629995737 100644
--- a/tests/providers/databricks/hooks/test_databricks_sql.py
+++ b/tests/providers/databricks/hooks/test_databricks_sql.py
@@ -297,7 +297,6 @@ def test_no_query(databricks_hook, empty_statement):
         pytest.param(Row("class")("data"), ("_0",)),
         pytest.param(Row("1_wrong", "2_wrong")(1, 2), ("_0", "_1")),
     ],
-
 )
 def test_incorrect_column_names(row_objects, fields_names):
     """Ensure that column names can be used as namedtuple attribute.
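
Both hooks in the two patches above lean on the same standard-library behaviour: collections.namedtuple(..., rename=True) replaces any field name that is not a valid Python identifier, or that is a reserved keyword, with a positional name. A minimal standalone sketch of that behaviour follows; the column names used are only illustrative, not taken from the patches.

    from collections import namedtuple

    # rename=True swaps invalid or reserved field names for positional ones: _0, _1, ...
    Row = namedtuple("Row", ["count(1)", "class", "id"], rename=True)

    print(Row._fields)           # ('_0', '_1', 'id')
    print(Row(9714, "data", 1))  # Row(_0=9714, _1='data', id=1)

Valid names such as "id" are kept as-is, which is why the tests above only expect renamed fields for columns like "count(1)" or "class".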
From df534cd48653b675376bd018c6f1e9816101c1ce Mon Sep 17 00:00:00 2001
From: Joffrey Bienvenu
Date: Mon, 22 Jan 2024 14:54:42 +0100
Subject: [PATCH 3/4] fix: Set generic ignore for namedtuple

---
 airflow/providers/odbc/hooks/odbc.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/airflow/providers/odbc/hooks/odbc.py b/airflow/providers/odbc/hooks/odbc.py
index 71a838f17c7de..fae37c56d24a1 100644
--- a/airflow/providers/odbc/hooks/odbc.py
+++ b/airflow/providers/odbc/hooks/odbc.py
@@ -238,8 +238,8 @@ def _make_common_data_structure(self, result: Sequence[Row] | Row) -> list[tuple
             return []
         if isinstance(result, Sequence):
             field_names = [col[0] for col in result[0].cursor_description]
-            row_object = namedtuple("Row", field_names, rename=True)  # type: ignore[misc]
+            row_object = namedtuple("Row", field_names, rename=True)  # type: ignore
             return cast(List[tuple], [row_object(*row) for row in result])
         else:
             field_names = [col[0] for col in result.cursor_description]
-            return cast(tuple, namedtuple("Row", field_names, rename=True)(*result))  # type: ignore[misc, operator]
+            return cast(tuple, namedtuple("Row", field_names, rename=True)(*result))  # type: ignore

From 683437dd4c0b6e98974e84310ad1563ed1a45eca Mon Sep 17 00:00:00 2001
From: Joffrey Bienvenu
Date: Mon, 22 Jan 2024 15:21:29 +0100
Subject: [PATCH 4/4] fix: Correctly use Databricks Row fields when Row is created in two steps

---
 .../providers/databricks/hooks/databricks_sql.py |  8 ++++----
 .../databricks/hooks/test_databricks_sql.py      | 14 +++++++++++++-
 2 files changed, 17 insertions(+), 5 deletions(-)

diff --git a/airflow/providers/databricks/hooks/databricks_sql.py b/airflow/providers/databricks/hooks/databricks_sql.py
index a7ebcdc180c06..bc701c9975f2c 100644
--- a/airflow/providers/databricks/hooks/databricks_sql.py
+++ b/airflow/providers/databricks/hooks/databricks_sql.py
@@ -282,13 +282,13 @@ def _make_common_data_structure(self, result: Sequence[Row] | Row) -> list[tuple
             rows: list[Row] = result
             if not rows:
                 return []
-            rows_fields = rows[0].__fields__
-            rows_object = namedtuple("Row", rows_fields, rename=True)  # type: ignore[misc]
+            rows_fields = tuple(rows[0].__fields__)
+            rows_object = namedtuple("Row", rows_fields, rename=True)  # type: ignore
             return cast(List[tuple], [rows_object(*row) for row in rows])
         else:
             row: Row = result
-            row_fields = row.__fields__
-            row_object = namedtuple("Row", row_fields, rename=True)  # type: ignore[misc]
+            row_fields = tuple(row.__fields__)
+            row_object = namedtuple("Row", row_fields, rename=True)  # type: ignore
             return cast(tuple, row_object(*row))

     def bulk_dump(self, table, tmp_file):
diff --git a/tests/providers/databricks/hooks/test_databricks_sql.py b/tests/providers/databricks/hooks/test_databricks_sql.py
index 6fb4629995737..b118fdb95e413 100644
--- a/tests/providers/databricks/hooks/test_databricks_sql.py
+++ b/tests/providers/databricks/hooks/test_databricks_sql.py
@@ -193,6 +193,18 @@ def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]:
             ],
             id="The return_last not set on multiple queries not set",
         ),
+        pytest.param(
+            True,
+            False,
+            "select * from test.test",
+            ["select * from test.test"],
+            True,
+            [["id", "value"]],
+            ([Row("id", "value")(1, 2)],),
+            [[("id",), ("value",)]],
+            [SerializableRow(1, 2)],
+            id="Return a serializable row (tuple) from a row instance created in two step",
+        ),
         pytest.param(
             True,
             False,
             "select * from test.test",
             ["select * from test.test"],
             True,
             [["id", "value"]],
             ([Row(id=1, value=2)],),
             [[("id",), ("value",)]],
             [SerializableRow(1, 2)],
-            id="The return_last set and no split statements set on single query in string",
+            id="Return a serializable row (tuple) from a row instance created in one step",
         ),
         pytest.param(
             True,
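
On the ODBC side of this series, the field names fed to namedtuple come straight from the DB-API cursor description, where each entry begins with the column name (col[0]); the old col[:2] slice produced (name, type_code) pairs for typing.NamedTuple, which collections.namedtuple cannot accept. A rough standalone sketch of that flow, using a made-up description and result set in place of a live pyodbc cursor:

    from __future__ import annotations

    from collections import namedtuple
    from collections.abc import Sequence


    def rows_to_namedtuples(rows: Sequence[tuple], cursor_description: Sequence[tuple]) -> list[tuple]:
        """Rename-aware conversion of plain result rows into namedtuples."""
        # Each DB-API description entry begins with the column name, so col[0] is the name.
        field_names = [col[0] for col in cursor_description]
        row_cls = namedtuple("Row", field_names, rename=True)
        return [row_cls(*row) for row in rows]


    # Made-up description and rows, standing in for pyodbc's cursor_description and fetched rows.
    description = [("count(1)", int), ("value", str)]
    print(rows_to_namedtuples([(9714, "a"), (2, "b")], description))
    # [Row(_0=9714, value='a'), Row(_0=2, value='b')]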