diff --git a/airflow/providers/postgres/hooks/postgres.py b/airflow/providers/postgres/hooks/postgres.py index 9e1b3a83d7282..6a24cb03160d8 100644 --- a/airflow/providers/postgres/hooks/postgres.py +++ b/airflow/providers/postgres/hooks/postgres.py @@ -28,6 +28,7 @@ import psycopg2.extras from deprecated import deprecated from psycopg2.extras import DictCursor, NamedTupleCursor, RealDictCursor +from sqlalchemy.engine import URL from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.providers.common.sql.hooks.sql import DbApiHook @@ -113,6 +114,18 @@ def schema(self): def schema(self, value): self.database = value + @property + def sqlalchemy_url(self) -> URL: + conn = self.get_connection(getattr(self, self.conn_name_attr)) + return URL.create( + drivername="postgresql", + username=conn.login, + password=conn.password, + host=conn.host, + port=conn.port, + database=self.database or conn.schema, + ) + def _get_cursor(self, raw_cursor: str) -> CursorType: _cursor = raw_cursor.lower() cursor_types = { @@ -186,12 +199,9 @@ def copy_expert(self, sql: str, filename: str) -> None: def get_uri(self) -> str: """Extract the URI from the connection. - :return: the extracted uri. + :return: the extracted URI in Sqlalchemy URI format. """ - conn = self.get_connection(getattr(self, self.conn_name_attr)) - conn.schema = self.database or conn.schema - uri = conn.get_uri().replace("postgres://", "postgresql://") - return uri + return self.sqlalchemy_url.render_as_string(hide_password=False) def bulk_load(self, table: str, tmp_file: str) -> None: """Load a tab-delimited file into a database table.""" diff --git a/tests/providers/postgres/hooks/test_postgres.py b/tests/providers/postgres/hooks/test_postgres.py index 8330ad3b1d53f..78d3414ab0a53 100644 --- a/tests/providers/postgres/hooks/test_postgres.py +++ b/tests/providers/postgres/hooks/test_postgres.py @@ -58,11 +58,11 @@ def test_get_conn(self, mock_connect): @mock.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect") def test_get_uri(self, mock_connect): - self.connection.extra = json.dumps({"client_encoding": "utf-8"}) self.connection.conn_type = "postgres" + self.connection.port = 5432 self.db_hook.get_conn() assert mock_connect.call_count == 1 - assert self.db_hook.get_uri() == "postgresql://login:password@host/database?client_encoding=utf-8" + assert self.db_hook.get_uri() == "postgresql://login:password@host:5432/database" @mock.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect") def test_get_conn_cursor(self, mock_connect):