diff --git a/airflow/providers/databricks/hooks/databricks_sql.py b/airflow/providers/databricks/hooks/databricks_sql.py index 5ffe404e317e5..bc701c9975f2c 100644 --- a/airflow/providers/databricks/hooks/databricks_sql.py +++ b/airflow/providers/databricks/hooks/databricks_sql.py @@ -282,13 +282,13 @@ def _make_common_data_structure(self, result: Sequence[Row] | Row) -> list[tuple rows: list[Row] = result if not rows: return [] - rows_fields = rows[0].__fields__ - rows_object = namedtuple("Row", rows_fields) # type: ignore[misc] + rows_fields = tuple(rows[0].__fields__) + rows_object = namedtuple("Row", rows_fields, rename=True) # type: ignore return cast(List[tuple], [rows_object(*row) for row in rows]) else: row: Row = result - row_fields = row.__fields__ - row_object = namedtuple("Row", row_fields) # type: ignore[misc] + row_fields = tuple(row.__fields__) + row_object = namedtuple("Row", row_fields, rename=True) # type: ignore return cast(tuple, row_object(*row)) def bulk_dump(self, table, tmp_file): diff --git a/airflow/providers/odbc/hooks/odbc.py b/airflow/providers/odbc/hooks/odbc.py index 7b0c4d11e9e47..fae37c56d24a1 100644 --- a/airflow/providers/odbc/hooks/odbc.py +++ b/airflow/providers/odbc/hooks/odbc.py @@ -17,7 +17,8 @@ """This module contains ODBC hook.""" from __future__ import annotations -from typing import Any, List, NamedTuple, Sequence, cast +from collections import namedtuple +from typing import Any, List, Sequence, cast from urllib.parse import quote_plus from pyodbc import Connection, Row, connect @@ -229,16 +230,16 @@ def get_sqlalchemy_connection( return cnx def _make_common_data_structure(self, result: Sequence[Row] | Row) -> list[tuple] | tuple: - """Transform the pyodbc.Row objects returned from an SQL command into typed NamedTuples.""" - # Below ignored lines respect NamedTuple docstring, but mypy do not support dynamically - # instantiated typed Namedtuple, and will never do: https://github.com/python/mypy/issues/848 + """Transform 
the pyodbc.Row objects returned from an SQL command into namedtuples.""" + # Below ignored lines respect namedtuple docstring, but mypy does not support dynamically + # instantiated namedtuple, and never will: https://github.com/python/mypy/issues/848 field_names: list[tuple[str, type]] | None = None if not result: return [] if isinstance(result, Sequence): - field_names = [col[:2] for col in result[0].cursor_description] - row_object = NamedTuple("Row", field_names) # type: ignore[misc] + field_names = [col[0] for col in result[0].cursor_description] + row_object = namedtuple("Row", field_names, rename=True) # type: ignore return cast(List[tuple], [row_object(*row) for row in result]) else: - field_names = [col[:2] for col in result.cursor_description] - return cast(tuple, NamedTuple("Row", field_names)(*result)) # type: ignore[misc, operator] + field_names = [col[0] for col in result.cursor_description] + return cast(tuple, namedtuple("Row", field_names, rename=True)(*result)) # type: ignore diff --git a/tests/providers/databricks/hooks/test_databricks_sql.py b/tests/providers/databricks/hooks/test_databricks_sql.py index 01392fb0a7dc3..b118fdb95e413 100644 --- a/tests/providers/databricks/hooks/test_databricks_sql.py +++ b/tests/providers/databricks/hooks/test_databricks_sql.py @@ -193,6 +193,18 @@ def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]: ], id="The return_last not set on multiple queries not set", ), + pytest.param( + True, + False, + "select * from test.test", + ["select * from test.test"], + True, + [["id", "value"]], + ([Row("id", "value")(1, 2)],), + [[("id",), ("value",)]], + [SerializableRow(1, 2)], + id="Return a serializable row (tuple) from a row instance created in two step", + ), pytest.param( True, False, @@ -203,7 +215,7 @@ def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]: ([Row(id=1, value=2)],), [[("id",), ("value",)]], [SerializableRow(1, 2)], - id="The return_last set and no split statements set on 
single query in string", + id="Return a serializable row (tuple) from a row instance created in one step", ), pytest.param( True, @@ -287,3 +299,22 @@ def test_no_query(databricks_hook, empty_statement): with pytest.raises(ValueError) as err: databricks_hook.run(sql=empty_statement) assert err.value.args[0] == "List of SQL statements is empty" + + +@pytest.mark.parametrize( + "row_objects, fields_names", + [ + pytest.param(Row("count(1)")(9714), ("_0",)), + pytest.param(Row("1//@:()")("data"), ("_0",)), + pytest.param(Row("class")("data"), ("_0",)), + pytest.param(Row("1_wrong", "2_wrong")(1, 2), ("_0", "_1")), + ], +) +def test_incorrect_column_names(row_objects, fields_names): + """Ensure that column names can be used as namedtuple attributes. + + namedtuple does not accept special characters and reserved Python keywords + as column names. This test ensures that such columns are renamed. + """ + result = DatabricksSqlHook()._make_common_data_structure(row_objects) + assert result._fields == fields_names