diff --git a/airflow/providers/odbc/CHANGELOG.rst b/airflow/providers/odbc/CHANGELOG.rst index 1a6242f98a62b..3b15ef16b79eb 100644 --- a/airflow/providers/odbc/CHANGELOG.rst +++ b/airflow/providers/odbc/CHANGELOG.rst @@ -24,6 +24,17 @@ Changelog --------- +4.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +The driver parameter has to be passed via keyword ``driver`` argument when initializing the Hook or via +``hook_params`` dictionary (with ``driver`` key) when instantiating Hook from SQL Operators. It was possible +to instantiate it via extras before, but in this version, only setting it via constructor is supported. + + 3.3.0 ..... diff --git a/airflow/providers/odbc/hooks/odbc.py b/airflow/providers/odbc/hooks/odbc.py index 72f137ff32a27..1b125ae2f215f 100644 --- a/airflow/providers/odbc/hooks/odbc.py +++ b/airflow/providers/odbc/hooks/odbc.py @@ -34,11 +34,17 @@ class OdbcHook(DbApiHook): :param args: passed to DbApiHook :param database: database to use -- overrides connection ``schema`` - :param driver: name of driver or path to driver. overrides driver supplied in connection ``extra`` + :param driver: name of driver or path to driver. You can also set the driver via: + * setting ``driver`` parameter in ``hook_params`` dictionary when instantiating hook by SQL operators. + * setting `driver`` extra in the connection and setting ``allow_driver_extra`` to True. + * setting ``OdbcHook.default_driver`` in ``local_settings.py`` file. :param dsn: name of DSN to use. overrides DSN supplied in connection ``extra`` :param connect_kwargs: keyword arguments passed to ``pyodbc.connect`` :param sqlalchemy_scheme: Scheme sqlalchemy connection. Default is ``mssql+pyodbc`` Only used for ``get_sqlalchemy_engine`` and ``get_sqlalchemy_connection`` methods. + :param allow_driver_extra: If True, allows to use driver extra in connection string (default False). + You should make sure that you trust the users who can edit connections in the UI to not use it + maliciously. :param kwargs: passed to DbApiHook """ @@ -49,6 +55,8 @@ class OdbcHook(DbApiHook): hook_name = "ODBC" supports_autocommit = True + default_driver: str | None = None + def __init__( self, *args, @@ -57,6 +65,7 @@ def __init__( dsn: str | None = None, connect_kwargs: dict | None = None, sqlalchemy_scheme: str | None = None, + allow_driver_extra: bool = False, **kwargs, ) -> None: super().__init__(*args, **kwargs) @@ -67,6 +76,7 @@ def __init__( self._sqlalchemy_scheme = sqlalchemy_scheme self._connection = None self._connect_kwargs = connect_kwargs + self._allow_driver_extra = allow_driver_extra @property def connection(self): @@ -101,11 +111,19 @@ def connection_extra_lower(self) -> dict: @property def driver(self) -> str | None: """Driver from init param if given; else try to find one in connection extra.""" + extra_driver = self.connection_extra_lower.get("driver") + if extra_driver: + if self._allow_driver_extra: + self._driver = extra_driver + else: + self.log.warning( + "Please provide driver via 'driver' parameter of the Hook constructor" + " or via 'hook_params' dictionary 'driver' key when instantiating hook by the" + " SQL operators. The 'driver' extra will not be used." + ) if not self._driver: - driver = self.connection_extra_lower.get("driver") - if driver: - self._driver = driver - return self._driver and self._driver.strip().lstrip("{").rstrip("}").strip() + self._driver = self.default_driver + return self._driver.strip().lstrip("{").rstrip("}").strip() if self._driver else None @property def dsn(self) -> str | None: diff --git a/tests/providers/odbc/hooks/test_odbc.py b/tests/providers/odbc/hooks/test_odbc.py index 58dcaf10d7b1f..3d21806cf98a7 100644 --- a/tests/providers/odbc/hooks/test_odbc.py +++ b/tests/providers/odbc/hooks/test_odbc.py @@ -18,7 +18,9 @@ from __future__ import annotations import json +import logging from unittest import mock +from unittest.mock import patch from urllib.parse import quote_plus, urlsplit import pyodbc @@ -43,11 +45,12 @@ def get_hook(self=None, hook_params=None, conn_params=None): hook.get_connection.return_value = connection return hook - def test_driver_in_extra(self): + def test_driver_in_extra_not_used(self): conn_params = dict(extra=json.dumps(dict(Driver="Fake Driver", Fake_Param="Fake Param"))) - hook = self.get_hook(conn_params=conn_params) + hook_params = {"driver": "ParamDriver"} + hook = self.get_hook(conn_params=conn_params, hook_params=hook_params) expected = ( - "DRIVER={Fake Driver};" + "DRIVER={ParamDriver};" "SERVER=host;" "DATABASE=schema;" "UID=login;" @@ -177,11 +180,40 @@ def test_driver(self): assert hook.driver == "Blah driver" hook = self.get_hook(hook_params=dict(driver="{Blah driver}")) assert hook.driver == "Blah driver" - hook = self.get_hook(conn_params=dict(extra='{"driver": "Blah driver"}')) - assert hook.driver == "Blah driver" - hook = self.get_hook(conn_params=dict(extra='{"driver": "{Blah driver}"}')) + + def test_driver_extra_raises_warning_by_default(self, caplog): + with caplog.at_level(logging.WARNING, logger="airflow.providers.odbc.hooks.test_odbc"): + driver = self.get_hook(conn_params=dict(extra='{"driver": "Blah driver"}')).driver + assert "Please provide driver via 'driver' parameter of the Hook" in caplog.text + assert driver is None + + def test_driver_extra_works_when_allow_driver_extra(self): + hook = self.get_hook( + conn_params=dict(extra='{"driver": "Blah driver"}'), hook_params=dict(allow_driver_extra=True) + ) assert hook.driver == "Blah driver" + def test_default_driver_set(self): + with patch.object(OdbcHook, "default_driver", "Blah driver"): + hook = self.get_hook() + assert hook.driver == "Blah driver" + + def test_driver_extra_works_when_default_driver_set(self): + with patch.object(OdbcHook, "default_driver", "Blah driver"): + hook = self.get_hook() + assert hook.driver == "Blah driver" + + def test_driver_none_by_default(self): + hook = self.get_hook() + assert hook.driver is None + + def test_driver_extra_raises_warning_and_returns_default_driver_by_default(self, caplog): + with patch.object(OdbcHook, "default_driver", "Blah driver"): + with caplog.at_level(logging.WARNING, logger="airflow.providers.odbc.hooks.test_odbc"): + driver = self.get_hook(conn_params=dict(extra='{"driver": "Blah driver2"}')).driver + assert "Please provide driver via 'driver' parameter of the Hook" in caplog.text + assert driver == "Blah driver" + def test_database(self): hook = self.get_hook(hook_params=dict(database="abc")) assert hook.database == "abc"