Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions airflow/providers/odbc/CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,17 @@
Changelog
---------

4.0.0
.....

Breaking changes
~~~~~~~~~~~~~~~~

The driver parameter has to be passed via keyword ``driver`` argument when initializing the Hook or via
``hook_params`` dictionary (with ``driver`` key) when instantiating Hook from SQL Operators. It was possible
to instantiate it via extras before, but in this version, only setting it via constructor is supported.


3.3.0
.....

Expand Down
28 changes: 23 additions & 5 deletions airflow/providers/odbc/hooks/odbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,17 @@ class OdbcHook(DbApiHook):

:param args: passed to DbApiHook
:param database: database to use -- overrides connection ``schema``
:param driver: name of driver or path to driver. overrides driver supplied in connection ``extra``
:param driver: name of driver or path to driver. You can also set the driver via:
* setting ``driver`` parameter in ``hook_params`` dictionary when instantiating hook by SQL operators.
* setting `driver`` extra in the connection and setting ``allow_driver_extra`` to True.
* setting ``OdbcHook.default_driver`` in ``local_settings.py`` file.
:param dsn: name of DSN to use. overrides DSN supplied in connection ``extra``
:param connect_kwargs: keyword arguments passed to ``pyodbc.connect``
:param sqlalchemy_scheme: Scheme sqlalchemy connection. Default is ``mssql+pyodbc`` Only used for
``get_sqlalchemy_engine`` and ``get_sqlalchemy_connection`` methods.
:param allow_driver_extra: If True, allows to use driver extra in connection string (default False).
You should make sure that you trust the users who can edit connections in the UI to not use it
maliciously.
:param kwargs: passed to DbApiHook
"""

Expand All @@ -49,6 +55,8 @@ class OdbcHook(DbApiHook):
hook_name = "ODBC"
supports_autocommit = True

default_driver: str | None = None

def __init__(
self,
*args,
Expand All @@ -57,6 +65,7 @@ def __init__(
dsn: str | None = None,
connect_kwargs: dict | None = None,
sqlalchemy_scheme: str | None = None,
allow_driver_extra: bool = False,
**kwargs,
) -> None:
super().__init__(*args, **kwargs)
Expand All @@ -67,6 +76,7 @@ def __init__(
self._sqlalchemy_scheme = sqlalchemy_scheme
self._connection = None
self._connect_kwargs = connect_kwargs
self._allow_driver_extra = allow_driver_extra

@property
def connection(self):
Expand Down Expand Up @@ -101,11 +111,19 @@ def connection_extra_lower(self) -> dict:
@property
def driver(self) -> str | None:
"""Driver from init param if given; else try to find one in connection extra."""
extra_driver = self.connection_extra_lower.get("driver")
if extra_driver:
if self._allow_driver_extra:
self._driver = extra_driver
else:
self.log.warning(
"Please provide driver via 'driver' parameter of the Hook constructor"
" or via 'hook_params' dictionary 'driver' key when instantiating hook by the"
" SQL operators. The 'driver' extra will not be used."
)
if not self._driver:
driver = self.connection_extra_lower.get("driver")
if driver:
self._driver = driver
return self._driver and self._driver.strip().lstrip("{").rstrip("}").strip()
self._driver = self.default_driver
return self._driver.strip().lstrip("{").rstrip("}").strip() if self._driver else None

@property
def dsn(self) -> str | None:
Expand Down
44 changes: 38 additions & 6 deletions tests/providers/odbc/hooks/test_odbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
from __future__ import annotations

import json
import logging
from unittest import mock
from unittest.mock import patch
from urllib.parse import quote_plus, urlsplit

import pyodbc
Expand All @@ -43,11 +45,12 @@ def get_hook(self=None, hook_params=None, conn_params=None):
hook.get_connection.return_value = connection
return hook

def test_driver_in_extra(self):
def test_driver_in_extra_not_used(self):
conn_params = dict(extra=json.dumps(dict(Driver="Fake Driver", Fake_Param="Fake Param")))
hook = self.get_hook(conn_params=conn_params)
hook_params = {"driver": "ParamDriver"}
hook = self.get_hook(conn_params=conn_params, hook_params=hook_params)
expected = (
"DRIVER={Fake Driver};"
"DRIVER={ParamDriver};"
"SERVER=host;"
"DATABASE=schema;"
"UID=login;"
Expand Down Expand Up @@ -177,11 +180,40 @@ def test_driver(self):
assert hook.driver == "Blah driver"
hook = self.get_hook(hook_params=dict(driver="{Blah driver}"))
assert hook.driver == "Blah driver"
hook = self.get_hook(conn_params=dict(extra='{"driver": "Blah driver"}'))
assert hook.driver == "Blah driver"
hook = self.get_hook(conn_params=dict(extra='{"driver": "{Blah driver}"}'))

def test_driver_extra_raises_warning_by_default(self, caplog):
with caplog.at_level(logging.WARNING, logger="airflow.providers.odbc.hooks.test_odbc"):
driver = self.get_hook(conn_params=dict(extra='{"driver": "Blah driver"}')).driver
assert "Please provide driver via 'driver' parameter of the Hook" in caplog.text
assert driver is None

def test_driver_extra_works_when_allow_driver_extra(self):
hook = self.get_hook(
conn_params=dict(extra='{"driver": "Blah driver"}'), hook_params=dict(allow_driver_extra=True)
)
assert hook.driver == "Blah driver"

def test_default_driver_set(self):
with patch.object(OdbcHook, "default_driver", "Blah driver"):
hook = self.get_hook()
assert hook.driver == "Blah driver"

def test_driver_extra_works_when_default_driver_set(self):
with patch.object(OdbcHook, "default_driver", "Blah driver"):
hook = self.get_hook()
assert hook.driver == "Blah driver"

def test_driver_none_by_default(self):
hook = self.get_hook()
assert hook.driver is None

def test_driver_extra_raises_warning_and_returns_default_driver_by_default(self, caplog):
with patch.object(OdbcHook, "default_driver", "Blah driver"):
with caplog.at_level(logging.WARNING, logger="airflow.providers.odbc.hooks.test_odbc"):
driver = self.get_hook(conn_params=dict(extra='{"driver": "Blah driver2"}')).driver
assert "Please provide driver via 'driver' parameter of the Hook" in caplog.text
assert driver == "Blah driver"

def test_database(self):
hook = self.get_hook(hook_params=dict(database="abc"))
assert hook.database == "abc"
Expand Down