diff --git a/providers/google/src/airflow/providers/google/cloud/links/base.py b/providers/google/src/airflow/providers/google/cloud/links/base.py index ea6d8850c95f5..9136b377910e3 100644 --- a/providers/google/src/airflow/providers/google/cloud/links/base.py +++ b/providers/google/src/airflow/providers/google/cloud/links/base.py @@ -18,6 +18,7 @@ from __future__ import annotations from typing import TYPE_CHECKING, ClassVar +from urllib.parse import urlparse from airflow.providers.google.version_compat import ( AIRFLOW_V_3_0_PLUS, @@ -109,6 +110,14 @@ def get_link( if TYPE_CHECKING: assert isinstance(operator, (GoogleCloudBaseOperator, BaseSensorOperator)) + # In cases when worker passes execution to trigger, the value that is put to XCom + # already contains link to the object in string format. In this case we don't want to execute + # get_config() again. Instead we can leave this value without any changes + link_value = XCom.get_value(key=self.key, ti_key=ti_key) + if link_value and isinstance(link_value, str): + if urlparse(link_value).scheme in ("http", "https"): + return link_value + conf = self.get_config(operator, ti_key) if not conf: return "" diff --git a/providers/google/tests/unit/google/cloud/links/test_base_link.py b/providers/google/tests/unit/google/cloud/links/test_base_link.py index b74963d916ed3..6440eb865391d 100644 --- a/providers/google/tests/unit/google/cloud/links/test_base_link.py +++ b/providers/google/tests/unit/google/cloud/links/test_base_link.py @@ -29,6 +29,9 @@ if AIRFLOW_V_3_0_PLUS: from airflow.sdk.execution_time.comms import XComResult + from airflow.sdk.execution_time.xcom import XCom +else: + from airflow.models.xcom import XCom # type: ignore[no-redef] TEST_LOCATION = "test-location" TEST_CLUSTER_ID = "test-cluster-id" @@ -128,3 +131,73 @@ def test_get_link(self, create_task_instance_of_operator, session, mock_supervis ) actual_url = link.get_link(operator=ti.task, ti_key=ti.key) assert actual_url == expected_url + + 
@pytest.mark.db_test
@mock.patch.object(XCom, "get_value")
def test_get_link_uses_xcom_url_and_skips_get_config(
    self,
    mock_get_value,
    create_task_instance_of_operator,
    session,
):
    """An http(s) URL already stored in XCom is returned as-is and ``get_config`` is never called."""
    stored_url = "https://console.cloud.google.com/some/service?project=test-proj"
    mock_get_value.return_value = stored_url

    google_link = GoogleLink()
    task_instance = create_task_instance_of_operator(
        MyOperator,
        dag_id="test_link_dag",
        task_id="test_link_task",
        location=TEST_LOCATION,
        cluster_id=TEST_CLUSTER_ID,
        project_id=TEST_PROJECT_ID,
    )
    session.add(task_instance)
    session.commit()

    # Replace get_config for the duration of the call so any use of it is recorded.
    config_patcher = mock.patch.object(GoogleLink, "get_config", autospec=True)
    with config_patcher as get_config_mock:
        resolved_url = google_link.get_link(
            operator=task_instance.task,
            ti_key=task_instance.key,
        )

    assert resolved_url == stored_url
    get_config_mock.assert_not_called()

@pytest.mark.db_test
@mock.patch.object(XCom, "get_value")
def test_get_link_falls_back_to_get_config_when_xcom_not_http(
    self,
    mock_get_value,
    create_task_instance_of_operator,
    session,
):
    """A non-http(s) XCom string does not short-circuit: the link is built via ``get_config``."""
    # A string value whose scheme is neither http nor https.
    mock_get_value.return_value = "gs://bucket/path"

    google_link = GoogleLink()
    task_instance = create_task_instance_of_operator(
        MyOperator,
        dag_id="test_link_dag",
        task_id="test_link_task",
        location=TEST_LOCATION,
        cluster_id=TEST_CLUSTER_ID,
        project_id=TEST_PROJECT_ID,
    )
    session.add(task_instance)
    session.commit()

    formatted_url = "https://console.cloud.google.com/expected/link?project=test-proj"
    operator = task_instance.task
    # Config the stubbed get_config hands back, taken from the operator's own attributes.
    stub_config = {
        "project_id": operator.project_id,
        "location": operator.location,
        "cluster_id": operator.cluster_id,
    }
    with (
        mock.patch.object(GoogleLink, "get_config", return_value=stub_config) as get_config_mock,
        mock.patch.object(GoogleLink, "_format_link", return_value=formatted_url) as format_link_mock,
    ):
        resolved_url = google_link.get_link(operator=operator, ti_key=task_instance.key)

    assert resolved_url == formatted_url
    get_config_mock.assert_called_once()
    format_link_mock.assert_called_once()