Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import annotations

from typing import TYPE_CHECKING, ClassVar
from urllib.parse import urlparse

from airflow.providers.google.version_compat import (
AIRFLOW_V_3_0_PLUS,
Expand Down Expand Up @@ -109,6 +110,14 @@ def get_link(
if TYPE_CHECKING:
assert isinstance(operator, (GoogleCloudBaseOperator, BaseSensorOperator))

# In cases when worker passes execution to trigger, the value that is put to XCom
# already contains link to the object in string format. In this case we don't want to execute
# get_config() again. Instead we can leave this value without any changes
link_value = XCom.get_value(key=self.key, ti_key=ti_key)
if link_value and isinstance(link_value, str):
if urlparse(link_value).scheme in ("http", "https"):
return link_value

conf = self.get_config(operator, ti_key)
if not conf:
return ""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@

if AIRFLOW_V_3_0_PLUS:
from airflow.sdk.execution_time.comms import XComResult
from airflow.sdk.execution_time.xcom import XCom
else:
from airflow.models.xcom import XCom # type: ignore[no-redef]

TEST_LOCATION = "test-location"
TEST_CLUSTER_ID = "test-cluster-id"
Expand Down Expand Up @@ -128,3 +131,73 @@ def test_get_link(self, create_task_instance_of_operator, session, mock_supervis
)
actual_url = link.get_link(operator=ti.task, ti_key=ti.key)
assert actual_url == expected_url

@pytest.mark.db_test
@mock.patch.object(XCom, "get_value")
def test_get_link_uses_xcom_url_and_skips_get_config(
self,
mock_get_value,
create_task_instance_of_operator,
session,
):
xcom_url = "https://console.cloud.google.com/some/service?project=test-proj"
mock_get_value.return_value = xcom_url

link = GoogleLink()
ti = create_task_instance_of_operator(
MyOperator,
dag_id="test_link_dag",
task_id="test_link_task",
location=TEST_LOCATION,
cluster_id=TEST_CLUSTER_ID,
project_id=TEST_PROJECT_ID,
)
session.add(ti)
session.commit()

with mock.patch.object(GoogleLink, "get_config", autospec=True) as m_get_config:
actual_url = link.get_link(operator=ti.task, ti_key=ti.key)

assert actual_url == xcom_url
m_get_config.assert_not_called()

@pytest.mark.db_test
@mock.patch.object(XCom, "get_value")
def test_get_link_falls_back_to_get_config_when_xcom_not_http(
self,
mock_get_value,
create_task_instance_of_operator,
session,
):
mock_get_value.return_value = "gs://bucket/path"

link = GoogleLink()
ti = create_task_instance_of_operator(
MyOperator,
dag_id="test_link_dag",
task_id="test_link_task",
location=TEST_LOCATION,
cluster_id=TEST_CLUSTER_ID,
project_id=TEST_PROJECT_ID,
)
session.add(ti)
session.commit()

expected_formatted = "https://console.cloud.google.com/expected/link?project=test-proj"
with (
mock.patch.object(
GoogleLink,
"get_config",
return_value={
"project_id": ti.task.project_id,
"location": ti.task.location,
"cluster_id": ti.task.cluster_id,
},
) as m_get_config,
mock.patch.object(GoogleLink, "_format_link", return_value=expected_formatted) as m_fmt,
):
actual_url = link.get_link(operator=ti.task, ti_key=ti.key)

assert actual_url == expected_formatted
m_get_config.assert_called_once()
m_fmt.assert_called_once()