From 5aebb1b18af0f7ac4d4a81aa53ace73ca0d84319 Mon Sep 17 00:00:00 2001 From: josix Date: Tue, 9 Apr 2024 18:05:29 +0800 Subject: [PATCH] fix(airbyte/hooks): add schema and port to prevent InvalidURL error --- airflow/providers/airbyte/hooks/airbyte.py | 7 ++++++- tests/providers/airbyte/hooks/test_airbyte.py | 17 +++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/airflow/providers/airbyte/hooks/airbyte.py b/airflow/providers/airbyte/hooks/airbyte.py index 7c1132fa57183..1cecb5b031ec6 100644 --- a/airflow/providers/airbyte/hooks/airbyte.py +++ b/airflow/providers/airbyte/hooks/airbyte.py @@ -71,7 +71,12 @@ def __init__( async def get_headers_tenants_from_connection(self) -> tuple[dict[str, Any], str]: """Get Headers, tenants from the connection details.""" connection: Connection = await sync_to_async(self.get_connection)(self.http_conn_id) - base_url = connection.host + # schema defaults to HTTP + schema = connection.schema if connection.schema else "http" + base_url = f"{schema}://{connection.host}" + + if connection.port: + base_url += f":{connection.port}" if self.api_type == "config": credentials = f"{connection.login}:{connection.password}" diff --git a/tests/providers/airbyte/hooks/test_airbyte.py b/tests/providers/airbyte/hooks/test_airbyte.py index 18d935bda352a..6cf211909ef13 100644 --- a/tests/providers/airbyte/hooks/test_airbyte.py +++ b/tests/providers/airbyte/hooks/test_airbyte.py @@ -67,6 +67,23 @@ def test_submit_sync_connection(self, requests_mock): assert resp.status_code == 200 assert resp.json() == self._mock_sync_conn_success_response_body + @pytest.mark.asyncio + @pytest.mark.parametrize( + "host, port, schema, expected_base_url, description", + [ + ("test-airbyte", 8001, "http", "http://test-airbyte:8001", "uri_with_port_and_schema"), + ("test-airbyte", None, "https", "https://test-airbyte", "uri_with_schema"), + ("test-airbyte", None, None, "http://test-airbyte", "uri_without_port_and_schema"), + ], + ) + async def test_get_base_url(self, host, port, schema, expected_base_url, description): + conn_id = f"test_conn_{description}" + conn = Connection(conn_id=conn_id, conn_type="airbyte", host=host, port=port, schema=schema) + hook = AirbyteHook(airbyte_conn_id=conn_id) + db.merge_conn(conn) + _, base_url = await hook.get_headers_tenants_from_connection() + assert base_url == expected_base_url + def test_get_job_status(self, requests_mock): requests_mock.post( self.get_job_endpoint, status_code=200, json=self._mock_job_status_success_response_body