From f2c3fb73a090bea06806fef562032a32d5c4c6db Mon Sep 17 00:00:00 2001 From: kalyanr Date: Mon, 8 Apr 2024 17:45:01 +0530 Subject: [PATCH 01/12] update get_uri --- airflow/providers/postgres/hooks/postgres.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/airflow/providers/postgres/hooks/postgres.py b/airflow/providers/postgres/hooks/postgres.py index 0afb7740fe58f..c6769f4092c31 100644 --- a/airflow/providers/postgres/hooks/postgres.py +++ b/airflow/providers/postgres/hooks/postgres.py @@ -22,6 +22,7 @@ from contextlib import closing from copy import deepcopy from typing import TYPE_CHECKING, Any, Iterable, Union +from urllib.parse import quote import psycopg2 import psycopg2.extensions @@ -137,7 +138,7 @@ def get_conn(self) -> connection: conn_args = { "host": conn.host, "user": conn.login, - "password": conn.password, + "password": quote(conn.password), "dbname": self.database or conn.schema, "port": conn.port, } @@ -189,7 +190,7 @@ def get_uri(self) -> str: """ conn = self.get_connection(getattr(self, self.conn_name_attr)) conn.schema = self.database or conn.schema - uri = conn.get_uri().replace("postgres://", "postgresql://") + uri = f"postgresql://{conn.login}:{quote(conn.password)}@{conn.host}:{conn.port}/{conn.schema}" return uri def bulk_load(self, table: str, tmp_file: str) -> None: From fb4568ed932d6f8e0675cd1803216b50f88db0ca Mon Sep 17 00:00:00 2001 From: kalyanr Date: Mon, 8 Apr 2024 20:26:56 +0530 Subject: [PATCH 02/12] update get_uri Signed-off-by: kalyanr --- airflow/providers/postgres/hooks/postgres.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/airflow/providers/postgres/hooks/postgres.py b/airflow/providers/postgres/hooks/postgres.py index c6769f4092c31..deefaefc838c3 100644 --- a/airflow/providers/postgres/hooks/postgres.py +++ b/airflow/providers/postgres/hooks/postgres.py @@ -138,7 +138,7 @@ def get_conn(self) -> connection: conn_args = { "host": conn.host, "user": conn.login, - "password": quote(conn.password), + "password": conn.password, "dbname": self.database or conn.schema, "port": conn.port, } @@ -190,6 +190,7 @@ def get_uri(self) -> str: """ conn = self.get_connection(getattr(self, self.conn_name_attr)) conn.schema = self.database or conn.schema + conn.port = conn.port or "5432" uri = f"postgresql://{conn.login}:{quote(conn.password)}@{conn.host}:{conn.port}/{conn.schema}" return uri From 718c42d9500766d6c434e1a805790a094914af94 Mon Sep 17 00:00:00 2001 From: kalyanr Date: Mon, 8 Apr 2024 20:29:08 +0530 Subject: [PATCH 03/12] update docstring Signed-off-by: kalyanr --- airflow/providers/postgres/hooks/postgres.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/providers/postgres/hooks/postgres.py b/airflow/providers/postgres/hooks/postgres.py index deefaefc838c3..794c9b0f44c19 100644 --- a/airflow/providers/postgres/hooks/postgres.py +++ b/airflow/providers/postgres/hooks/postgres.py @@ -186,7 +186,7 @@ def copy_expert(self, sql: str, filename: str) -> None: def get_uri(self) -> str: """Extract the URI from the connection. - :return: the extracted uri. + :return: the extracted URI in Sqlalchemy URI format. """ conn = self.get_connection(getattr(self, self.conn_name_attr)) conn.schema = self.database or conn.schema From 7bb8687db9ff5fd26a762dab0ddabcfcf082e118 Mon Sep 17 00:00:00 2001 From: kalyanr Date: Mon, 8 Apr 2024 21:20:41 +0530 Subject: [PATCH 04/12] add and use sa_uri property --- airflow/providers/common/sql/hooks/sql.py | 5 +++++ airflow/providers/postgres/hooks/postgres.py | 20 ++++++++++++++------ 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/airflow/providers/common/sql/hooks/sql.py b/airflow/providers/common/sql/hooks/sql.py index 3f324e4f697a3..3ff4183126296 100644 --- a/airflow/providers/common/sql/hooks/sql.py +++ b/airflow/providers/common/sql/hooks/sql.py @@ -45,6 +45,7 @@ if TYPE_CHECKING: from pandas import DataFrame + from sqlalchemy.engine import URL from airflow.providers.openlineage.extractors import OperatorLineage from airflow.providers.openlineage.sqlparser import DatabaseInfo @@ -174,6 +175,10 @@ def __init__(self, *args, schema: str | None = None, log_sql: bool = True, **kwa "replace_statement_format", "REPLACE INTO {} {} VALUES ({})" ) + @property + def sa_uri(self) -> URL: + raise NotImplementedError + @property def placeholder(self): conn = self.get_connection(getattr(self, self.conn_name_attr)) diff --git a/airflow/providers/postgres/hooks/postgres.py b/airflow/providers/postgres/hooks/postgres.py index 794c9b0f44c19..9f5e40c597c3f 100644 --- a/airflow/providers/postgres/hooks/postgres.py +++ b/airflow/providers/postgres/hooks/postgres.py @@ -22,13 +22,13 @@ from contextlib import closing from copy import deepcopy from typing import TYPE_CHECKING, Any, Iterable, Union -from urllib.parse import quote import psycopg2 import psycopg2.extensions import psycopg2.extras from deprecated import deprecated from psycopg2.extras import DictCursor, NamedTupleCursor, RealDictCursor +from sqlalchemy.engine import URL from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.providers.common.sql.hooks.sql import DbApiHook @@ -113,6 +113,18 @@ def schema(self): def schema(self, value): self.database = value + @property + def sa_uri(self) -> URL: + conn = self.get_connection(getattr(self, self.conn_name_attr)) + return URL.create( + drivername="postgresql", + username=conn.login, + password=conn.password, + host=conn.host, + port=conn.port, + database=conn.schema, + ) + def _get_cursor(self, raw_cursor: str) -> CursorType: _cursor = raw_cursor.lower() cursor_types = { @@ -188,11 +200,7 @@ def get_uri(self) -> str: :return: the extracted URI in Sqlalchemy URI format. """ - conn = self.get_connection(getattr(self, self.conn_name_attr)) - conn.schema = self.database or conn.schema - conn.port = conn.port or "5432" - uri = f"postgresql://{conn.login}:{quote(conn.password)}@{conn.host}:{conn.port}/{conn.schema}" - return uri + return self.sa_uri.render_as_string(hide_password=False) def bulk_load(self, table: str, tmp_file: str) -> None: """Load a tab-delimited file into a database table.""" From bf48b0d620dd8793792a91d4a8621c1102fcc804 Mon Sep 17 00:00:00 2001 From: kalyanr Date: Mon, 8 Apr 2024 21:23:30 +0530 Subject: [PATCH 05/12] update database in sa_uri --- airflow/providers/postgres/hooks/postgres.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/providers/postgres/hooks/postgres.py b/airflow/providers/postgres/hooks/postgres.py index 9f5e40c597c3f..b7f1e2d4aa8a2 100644 --- a/airflow/providers/postgres/hooks/postgres.py +++ b/airflow/providers/postgres/hooks/postgres.py @@ -122,7 +122,7 @@ def sa_uri(self) -> URL: password=conn.password, host=conn.host, port=conn.port, - database=conn.schema, + database=self.database or conn.schema, ) def _get_cursor(self, raw_cursor: str) -> CursorType: From 45a51a6fb00845c645abde14dff0af2177489c49 Mon Sep 17 00:00:00 2001 From: kalyanr Date: Tue, 9 Apr 2024 10:02:12 +0530 Subject: [PATCH 06/12] update tests --- airflow/providers/postgres/hooks/postgres.py | 4 ++++ tests/providers/postgres/hooks/test_postgres.py | 10 ++++++++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/airflow/providers/postgres/hooks/postgres.py b/airflow/providers/postgres/hooks/postgres.py index b7f1e2d4aa8a2..ae85eae84dca3 100644 --- a/airflow/providers/postgres/hooks/postgres.py +++ b/airflow/providers/postgres/hooks/postgres.py @@ -171,6 +171,10 @@ def get_conn(self) -> connection: ]: conn_args[arg_name] = arg_val + client_encoding = conn.extra_dejson.get("client_encoding") + if isinstance(client_encoding, str): + conn_args["client_encoding"] = client_encoding + self.conn = psycopg2.connect(**conn_args) return self.conn diff --git a/tests/providers/postgres/hooks/test_postgres.py b/tests/providers/postgres/hooks/test_postgres.py index 2d62cb4f43129..dc20edf2fe641 100644 --- a/tests/providers/postgres/hooks/test_postgres.py +++ b/tests/providers/postgres/hooks/test_postgres.py @@ -155,10 +155,16 @@ def test_get_conn_rds_iam_postgres(self, mock_aws_hook_class, mock_connect, aws_ @mock.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect") def test_get_conn_extra(self, mock_connect): - self.connection.extra = '{"connect_timeout": 3}' + self.connection.extra = json.dumps({"client_encoding": "utf-8", "connect_timeout": 3}) self.db_hook.get_conn() mock_connect.assert_called_once_with( - user="login", password="password", host="host", dbname="database", port=None, connect_timeout=3 + user="login", + password="password", + host="host", + dbname="database", + port=None, + connect_timeout=3, + client_encoding="utf-8", ) @mock.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect") From 2a41dff79c5014fcb0b1f8253e54206381e06afa Mon Sep 17 00:00:00 2001 From: kalyanr Date: Tue, 9 Apr 2024 10:25:44 +0530 Subject: [PATCH 07/12] remove client_encoding from test_get_uri --- tests/providers/postgres/hooks/test_postgres.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/providers/postgres/hooks/test_postgres.py b/tests/providers/postgres/hooks/test_postgres.py index dc20edf2fe641..a1a63c15b4f18 100644 --- a/tests/providers/postgres/hooks/test_postgres.py +++ b/tests/providers/postgres/hooks/test_postgres.py @@ -58,11 +58,10 @@ def test_get_conn(self, mock_connect): @mock.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect") def test_get_uri(self, mock_connect): - self.connection.extra = json.dumps({"client_encoding": "utf-8"}) self.connection.conn_type = "postgres" self.db_hook.get_conn() assert mock_connect.call_count == 1 - assert self.db_hook.get_uri() == "postgresql://login:password@host/database?client_encoding=utf-8" + assert self.db_hook.get_uri() == "postgresql://login:password@host/database" @mock.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect") def test_get_conn_cursor(self, mock_connect): From bf107336990383223f4c6dda542dcbcaa7c55451 Mon Sep 17 00:00:00 2001 From: kalyanr Date: Sun, 14 Apr 2024 16:28:43 +0530 Subject: [PATCH 08/12] use sqlalchemy_url property --- airflow/providers/common/sql/hooks/sql.py | 4 ---- airflow/providers/postgres/hooks/postgres.py | 4 ++-- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/airflow/providers/common/sql/hooks/sql.py b/airflow/providers/common/sql/hooks/sql.py index 3277ed88737bb..070bf0647cfc6 100644 --- a/airflow/providers/common/sql/hooks/sql.py +++ b/airflow/providers/common/sql/hooks/sql.py @@ -181,10 +181,6 @@ def __init__(self, *args, schema: str | None = None, log_sql: bool = True, **kwa "replace_statement_format", "REPLACE INTO {} {} VALUES ({})" ) - @property - def sa_uri(self) -> URL: - raise NotImplementedError - @property def placeholder(self): conn = self.get_connection(getattr(self, self.conn_name_attr)) diff --git a/airflow/providers/postgres/hooks/postgres.py b/airflow/providers/postgres/hooks/postgres.py index b3010425cbd50..f0babf38868ba 100644 --- a/airflow/providers/postgres/hooks/postgres.py +++ b/airflow/providers/postgres/hooks/postgres.py @@ -115,7 +115,7 @@ def schema(self, value): self.database = value @property - def sa_uri(self) -> URL: + def sqlalchemy_url(self) -> URL: conn = self.get_connection(getattr(self, self.conn_name_attr)) return URL.create( drivername="postgresql", @@ -205,7 +205,7 @@ def get_uri(self) -> str: :return: the extracted URI in Sqlalchemy URI format. """ - return self.sa_uri.render_as_string(hide_password=False) + return self.sqlalchemy_url.render_as_string(hide_password=False) def bulk_load(self, table: str, tmp_file: str) -> None: """Load a tab-delimited file into a database table.""" From 6eea77d63379713c06075ad55fb7fa53b83b29bd Mon Sep 17 00:00:00 2001 From: kalyanr Date: Sun, 14 Apr 2024 16:33:05 +0530 Subject: [PATCH 09/12] add default port --- airflow/providers/postgres/hooks/postgres.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/providers/postgres/hooks/postgres.py b/airflow/providers/postgres/hooks/postgres.py index f0babf38868ba..dfb05cb08f62f 100644 --- a/airflow/providers/postgres/hooks/postgres.py +++ b/airflow/providers/postgres/hooks/postgres.py @@ -122,7 +122,7 @@ def sqlalchemy_url(self) -> URL: username=conn.login, password=conn.password, host=conn.host, - port=conn.port, + port=conn.port or 5432, database=self.database or conn.schema, ) From b332c7898f6980af2f9bd035ae07ee7c26f4a711 Mon Sep 17 00:00:00 2001 From: kalyanr Date: Sun, 14 Apr 2024 16:48:33 +0530 Subject: [PATCH 10/12] update tests --- tests/providers/postgres/hooks/test_postgres.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/providers/postgres/hooks/test_postgres.py b/tests/providers/postgres/hooks/test_postgres.py index d19434afc8021..b54ff62a38279 100644 --- a/tests/providers/postgres/hooks/test_postgres.py +++ b/tests/providers/postgres/hooks/test_postgres.py @@ -61,7 +61,7 @@ def test_get_uri(self, mock_connect): self.connection.conn_type = "postgres" self.db_hook.get_conn() assert mock_connect.call_count == 1 - assert self.db_hook.get_uri() == "postgresql://login:password@host/database" + assert self.db_hook.get_uri() == "postgresql://login:password@host:5432/database" @mock.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect") def test_get_conn_cursor(self, mock_connect): From ed4c2a226628c60f0cdcc1343808332909b26fbd Mon Sep 17 00:00:00 2001 From: kalyanr Date: Sun, 14 Apr 2024 17:13:55 +0530 Subject: [PATCH 11/12] update usage of ports --- airflow/providers/postgres/hooks/postgres.py | 2 +- tests/providers/postgres/hooks/test_postgres.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/airflow/providers/postgres/hooks/postgres.py b/airflow/providers/postgres/hooks/postgres.py index dfb05cb08f62f..f0babf38868ba 100644 --- a/airflow/providers/postgres/hooks/postgres.py +++ b/airflow/providers/postgres/hooks/postgres.py @@ -122,7 +122,7 @@ def sqlalchemy_url(self) -> URL: username=conn.login, password=conn.password, host=conn.host, - port=conn.port or 5432, + port=conn.port, database=self.database or conn.schema, ) diff --git a/tests/providers/postgres/hooks/test_postgres.py b/tests/providers/postgres/hooks/test_postgres.py index b54ff62a38279..b7f97f2860ea7 100644 --- a/tests/providers/postgres/hooks/test_postgres.py +++ b/tests/providers/postgres/hooks/test_postgres.py @@ -59,6 +59,7 @@ def test_get_conn(self, mock_connect): @mock.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect") def test_get_uri(self, mock_connect): self.connection.conn_type = "postgres" + self.connection.port = 5432 self.db_hook.get_conn() assert mock_connect.call_count == 1 assert self.db_hook.get_uri() == "postgresql://login:password@host:5432/database" From 28801e83cd559e649260343dd365e218163b3225 Mon Sep 17 00:00:00 2001 From: kalyanr Date: Mon, 6 May 2024 18:26:07 +0530 Subject: [PATCH 12/12] revert client_encoding updates --- airflow/providers/postgres/hooks/postgres.py | 4 ---- tests/providers/postgres/hooks/test_postgres.py | 10 ++-------- 2 files changed, 2 insertions(+), 12 deletions(-) diff --git a/airflow/providers/postgres/hooks/postgres.py b/airflow/providers/postgres/hooks/postgres.py index f0babf38868ba..6a24cb03160d8 100644 --- a/airflow/providers/postgres/hooks/postgres.py +++ b/airflow/providers/postgres/hooks/postgres.py @@ -172,10 +172,6 @@ def get_conn(self) -> connection: ]: conn_args[arg_name] = arg_val - client_encoding = conn.extra_dejson.get("client_encoding") - if isinstance(client_encoding, str): - conn_args["client_encoding"] = client_encoding - self.conn = psycopg2.connect(**conn_args) return self.conn diff --git a/tests/providers/postgres/hooks/test_postgres.py b/tests/providers/postgres/hooks/test_postgres.py index b7f97f2860ea7..78d3414ab0a53 100644 --- a/tests/providers/postgres/hooks/test_postgres.py +++ b/tests/providers/postgres/hooks/test_postgres.py @@ -155,16 +155,10 @@ def test_get_conn_rds_iam_postgres(self, mock_aws_hook_class, mock_connect, aws_ @mock.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect") def test_get_conn_extra(self, mock_connect): - self.connection.extra = json.dumps({"client_encoding": "utf-8", "connect_timeout": 3}) + self.connection.extra = '{"connect_timeout": 3}' self.db_hook.get_conn() mock_connect.assert_called_once_with( - user="login", - password="password", - host="host", - dbname="database", - port=None, - connect_timeout=3, - client_encoding="utf-8", + user="login", password="password", host="host", dbname="database", port=None, connect_timeout=3 ) @mock.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect")