Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions airflow/api_connexion/endpoints/extra_link_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def get_extra_links(
dag_id: str,
dag_run_id: str,
task_id: str,
map_index: int = -1,
session: Session = NEW_SESSION,
) -> APIResponse:
"""Get extra links for task instance."""
Expand All @@ -62,6 +63,7 @@ def get_extra_links(
TaskInstance.dag_id == dag_id,
TaskInstance.run_id == dag_run_id,
TaskInstance.task_id == task_id,
TaskInstance.map_index == map_index,
)
)

Expand Down
1 change: 1 addition & 0 deletions airflow/api_connexion/openapi/v1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2062,6 +2062,7 @@ paths:
- $ref: "#/components/parameters/DAGID"
- $ref: "#/components/parameters/DAGRunID"
- $ref: "#/components/parameters/TaskID"
- $ref: "#/components/parameters/FilterMapIndex"

get:
summary: List extra links
Expand Down
11 changes: 10 additions & 1 deletion airflow/www/static/js/types/api-generated.ts
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,10 @@ export interface paths {
/** The task ID. */
task_id: components["parameters"]["TaskID"];
};
query: {
/** Filter on map index for mapped task. */
map_index?: components["parameters"]["FilterMapIndex"];
};
};
};
"/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/logs/{task_try_number}": {
Expand Down Expand Up @@ -4741,6 +4745,10 @@ export interface operations {
/** The task ID. */
task_id: components["parameters"]["TaskID"];
};
query: {
/** Filter on map index for mapped task. */
map_index?: components["parameters"]["FilterMapIndex"];
};
};
responses: {
/** Success. */
Expand Down Expand Up @@ -5990,7 +5998,8 @@ export type GetXcomEntryVariables = CamelCasedPropertiesDeep<
operations["get_xcom_entry"]["parameters"]["query"]
>;
export type GetExtraLinksVariables = CamelCasedPropertiesDeep<
operations["get_extra_links"]["parameters"]["path"]
operations["get_extra_links"]["parameters"]["path"] &
operations["get_extra_links"]["parameters"]["query"]
>;
export type GetLogVariables = CamelCasedPropertiesDeep<
operations["get_log"]["parameters"]["path"] &
Expand Down
64 changes: 62 additions & 2 deletions tests/api_connexion/endpoints/test_extra_link_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from airflow.models.xcom import XCom
from airflow.plugins_manager import AirflowPlugin
from airflow.security import permissions
from airflow.serialization.serialized_objects import SerializedBaseOperator
from airflow.timetables.base import DataInterval
from airflow.utils import timezone
from airflow.utils.state import DagRunState
Expand Down Expand Up @@ -62,7 +63,7 @@ def configured_app(minimal_app_for_api):
delete_user(app, username="test_no_permissions") # type: ignore


class TestGetExtraLinks:
class BaseGetExtraLinks:
@pytest.fixture(autouse=True)
def setup_attrs(self, configured_app, session) -> None:
self.default_time = timezone.datetime(2020, 1, 1)
Expand All @@ -72,7 +73,7 @@ def setup_attrs(self, configured_app, session) -> None:

self.app = configured_app

self.dag = self._create_dag()
self.dag = self._create_dag() # type: ignore

self.app.dag_bag = DagBag(os.devnull, include_examples=False)
self.app.dag_bag.dags = {self.dag.dag_id: self.dag} # type: ignore
Expand All @@ -94,6 +95,8 @@ def teardown_method(self) -> None:
clear_db_runs()
clear_db_xcom()


class TestGetExtraLinks(BaseGetExtraLinks):
def _create_dag(self):
with DAG(dag_id="TEST_DAG_ID", schedule=None, default_args={"start_date": self.default_time}) as dag:
CustomOperator(task_id="TEST_SINGLE_LINK", bash_command="TEST_LINK_VALUE")
Expand Down Expand Up @@ -241,3 +244,60 @@ class AirflowTestPlugin(AirflowPlugin):
"TEST_DAG_ID/TEST_SINGLE_LINK/2020-01-01T00%3A00%3A00%2B00%3A00"
),
} == response.json


class TestMappedTaskExtraLinks(BaseGetExtraLinks):
def _create_dag(self):
with DAG(dag_id="TEST_DAG_ID", schedule=None, default_args={"start_date": self.default_time}) as dag:
# Mapped task expanded over a list of bash_commands
CustomOperator.partial(task_id="TEST_MAPPED_TASK").expand(
bash_command=["TEST_LINK_VALUE_3", "TEST_LINK_VALUE_4"]
)
return SerializedBaseOperator.deserialize(SerializedBaseOperator.serialize(dag))

@pytest.mark.parametrize(
"map_index, expected_status, expected_json",
[
(
0,
200,
{
"Google Custom": "http://google.com/custom_base_link?search=TEST_LINK_VALUE_3",
"google": "https://www.google.com",
},
),
(
1,
200,
{
"Google Custom": "http://google.com/custom_base_link?search=TEST_LINK_VALUE_4",
"google": "https://www.google.com",
},
),
(6, 404, {"detail": 'DAG Run with ID = "TEST_DAG_RUN_ID" not found'}),
],
)
@mock_plugin_manager(plugins=[])
def test_mapped_task_links(self, map_index, expected_status, expected_json):
"""Parameterized test for mapped task extra links."""
# Set XCom data for different map indices
if map_index < 2:
XCom.set(
key="search_query",
value=f"TEST_LINK_VALUE_{map_index + 3}",
task_id="TEST_MAPPED_TASK",
dag_id="TEST_DAG_ID",
run_id="TEST_DAG_RUN_ID",
map_index=map_index,
)

response = self.client.get(
f"/api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID/taskInstances/TEST_MAPPED_TASK/links?map_index={map_index}",
environ_overrides={"REMOTE_USER": "test"},
)

assert response.status_code == expected_status
if map_index < 2:
assert response.json == expected_json
else:
assert response.json["detail"] == expected_json["detail"]
13 changes: 11 additions & 2 deletions tests/test_utils/mock_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import attr

from airflow.models.baseoperator import BaseOperator
from airflow.models.mappedoperator import MappedOperator
from airflow.models.xcom import XCom
from tests.test_utils.compat import BaseOperatorLink

Expand Down Expand Up @@ -137,7 +138,11 @@ class CustomOpLink(BaseOperatorLink):

def get_link(self, operator, *, ti_key):
search_query = XCom.get_one(
task_id=ti_key.task_id, dag_id=ti_key.dag_id, run_id=ti_key.run_id, key="search_query"
task_id=ti_key.task_id,
dag_id=ti_key.dag_id,
run_id=ti_key.run_id,
map_index=ti_key.map_index,
key="search_query",
)
if not search_query:
return None
Expand All @@ -153,7 +158,11 @@ def operator_extra_links(self):
"""
Return operator extra links
"""
if isinstance(self.bash_command, str) or self.bash_command is None:
if (
isinstance(self, MappedOperator)
or isinstance(self.bash_command, str)
or self.bash_command is None
):
return (CustomOpLink(),)
return (CustomBaseIndexOpLink(i) for i, _ in enumerate(self.bash_command))

Expand Down