Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
98d07ca
AIP-72: Changing extra links interface for task SDK
amoghrajesh Feb 10, 2025
bbf5ffd
fixing circular import
amoghrajesh Feb 10, 2025
478182a
getting plugins working yay
amoghrajesh Feb 10, 2025
a74445e
adding task runner test case
amoghrajesh Feb 10, 2025
f97d89e
removing print and adding comment
amoghrajesh Feb 10, 2025
967a54c
fixing tests and adding a default
amoghrajesh Feb 10, 2025
fb9cdcd
adding newsfragments
amoghrajesh Feb 10, 2025
479d9cc
removing xcom_key from links
amoghrajesh Feb 10, 2025
9e7e5e0
renaming newsfragments to feature
amoghrajesh Feb 10, 2025
626d13c
fixing unit tests
amoghrajesh Feb 10, 2025
99526f2
simpler nits from ash
amoghrajesh Feb 10, 2025
e84ab9e
harder review comments: moved the get_extra_links to serialisedBaseOp…
amoghrajesh Feb 11, 2025
1852c62
review comments from TP
amoghrajesh Feb 11, 2025
e15f58e
fixing core tests
amoghrajesh Feb 11, 2025
e40af66
fixing core tests 2
amoghrajesh Feb 11, 2025
7f754ca
fixing providers tests due to the interface change
amoghrajesh Feb 11, 2025
bbec2c0
Merge branch 'main' into AIP72-extra-links
amoghrajesh Feb 11, 2025
391d54b
partly fixing serialisation tests
amoghrajesh Feb 11, 2025
5f19c3f
fixing serialisation tests
amoghrajesh Feb 11, 2025
2e1e905
moving the functions using plugins manager from sdk => ser module
amoghrajesh Feb 11, 2025
a2a0cc5
fixing UTs for task runner
amoghrajesh Feb 11, 2025
32894b1
nits from ash
amoghrajesh Feb 11, 2025
90141a8
fixing tests in core
amoghrajesh Feb 12, 2025
ef9428e
adding to ini
amoghrajesh Feb 12, 2025
3446d61
Merge branch 'main' into AIP72-extra-links
amoghrajesh Feb 12, 2025
f2909b6
fixing the final failing test
amoghrajesh Feb 12, 2025
153c7e5
Merge branch 'main' into AIP72-extra-links
amoghrajesh Feb 13, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 0 additions & 62 deletions airflow/models/abstractoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,12 @@
from __future__ import annotations

import datetime
import inspect
from collections.abc import Iterable, Sequence
from functools import cached_property
from typing import TYPE_CHECKING, Any, Callable

from sqlalchemy import select

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.sdk.definitions._internal.abstractoperator import (
AbstractOperator as TaskSDKAbstractOperator,
NotMapped as NotMapped, # Re-export this for compat
Expand All @@ -42,7 +39,6 @@
if TYPE_CHECKING:
from sqlalchemy.orm import Session

from airflow.models.baseoperatorlink import BaseOperatorLink
from airflow.models.dag import DAG as SchedulerDAG
from airflow.models.taskinstance import TaskInstance
from airflow.sdk.definitions.baseoperator import BaseOperator
Expand Down Expand Up @@ -157,64 +153,6 @@ def priority_weight_total(self) -> int:
)
)

@cached_property
def operator_extra_link_dict(self) -> dict[str, Any]:
    """
    Map each extra-link name available to this operator to its link object.

    Links declared on the operator (``self.operator_extra_links``) are merged
    with the plugin-registered links whose ``operators`` list contains this
    operator's class. On a name clash the plugin link wins.

    :raise AirflowException: If ``plugins_manager.operator_extra_links`` is still
        ``None`` after the plugins have been initialized.
    """
    from airflow import plugins_manager

    plugins_manager.initialize_extra_operators_links_plugins()
    if plugins_manager.operator_extra_links is None:
        raise AirflowException("Can't load operators")

    # Keep only the plugin links that explicitly target this operator's class.
    links_from_plugins: dict[str, Any] = {
        plugin_link.name: plugin_link
        for plugin_link in plugins_manager.operator_extra_links
        if plugin_link.operators and self.operator_class in plugin_link.operators
    }
    links_from_operator = {own_link.name: own_link for own_link in self.operator_extra_links}

    # Plugin entries are unpacked last so they override same-named operator links.
    return {**links_from_operator, **links_from_plugins}

@cached_property
def global_operator_extra_link_dict(self) -> dict[str, Any]:
    """
    Map the name of every global extra link to its link object.

    The links are read from ``plugins_manager.global_operator_extra_links``
    after the extra-link plugins have been initialized.

    :raise AirflowException: If the global links are still ``None`` after
        initialization.
    """
    from airflow import plugins_manager

    plugins_manager.initialize_extra_operators_links_plugins()
    if plugins_manager.global_operator_extra_links is None:
        raise AirflowException("Can't load operators")

    links_by_name = {}
    for global_link in plugins_manager.global_operator_extra_links:
        links_by_name[global_link.name] = global_link
    return links_by_name

@cached_property
def extra_links(self) -> list[str]:
    """Return the sorted, de-duplicated names of all operator-level and global extra links."""
    # Unpacking a dict yields its keys, so this is the union of both name sets.
    link_names = {*self.operator_extra_link_dict, *self.global_operator_extra_link_dict}
    return sorted(link_names)

def get_extra_links(self, ti: TaskInstance, link_name: str) -> str | None:
    """
    Resolve the URL that one of this operator's ``extra_links`` entries points to.

    The link is looked up among the operator-level links first and among the
    global links second.

    :meta private:

    :raise ValueError: The error message of a ValueError will be passed on through to
        the frontend to show up as a tooltip on the disabled link.
    :param ti: The TaskInstance for the URL being searched for.
    :param link_name: The name of the link we're looking for the URL for. Should be
        one of the options specified in ``extra_links``.
    :return: The URL produced by the link, or ``None`` when no link has that name.
    """
    # The global dict is only consulted when the operator-level lookup finds nothing.
    link: BaseOperatorLink | None = self.operator_extra_link_dict.get(
        link_name
    ) or self.global_operator_extra_link_dict.get(link_name)
    if not link:
        return None

    # Links written against the legacy interface take ``(operator, dttm)``; current
    # ones declare a ``ti_key`` parameter. A ``**ti_key`` catch-all does not count.
    signature = inspect.signature(link.get_link)
    accepts_ti_key = any(
        param_name == "ti_key"
        for param_name, param in signature.parameters.items()
        if param.kind != param.VAR_KEYWORD
    )

    if accepts_ti_key:
        return link.get_link(self.unmap(None), ti_key=ti.key)
    return link.get_link(self.unmap(None), ti.dag_run.logical_date)  # type: ignore[misc]

def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence[TaskInstance], int]:
"""
Create the mapped task instances for mapped task.
Expand Down
54 changes: 52 additions & 2 deletions airflow/models/baseoperatorlink.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,53 @@
from abc import ABCMeta, abstractmethod
from typing import TYPE_CHECKING, ClassVar

import attr
import attrs

from airflow.models.xcom import BaseXCom
from airflow.utils.log.logging_mixin import LoggingMixin

if TYPE_CHECKING:
from airflow.models.baseoperator import BaseOperator
from airflow.models.taskinstancekey import TaskInstanceKey


@attr.s(auto_attribs=True)
@attrs.define()
class XComOperatorLink(LoggingMixin):
    """
    Generic operator link whose URL is read back from XCom.

    Used while deserializing operators: the link is described only by its
    display name and the XCom key under which the URL is stored.
    """

    # Name of the link.
    name: str
    # XCom key under which the link URL is stored.
    xcom_key: str

    def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey) -> str:
        """
        Retrieve the link from the XComs.

        :param operator: The Airflow operator object this link is associated to.
        :param ti_key: TaskInstance ID to return link for.
        :return: The link to the external system as stored in XCom, or an empty
            string when nothing is stored under ``xcom_key``.
        """
        self.log.info(
            "Attempting to retrieve link from XComs with key: %s for task id: %s", self.xcom_key, ti_key
        )
        stored_link = BaseXCom.get_one(
            key=self.xcom_key,
            run_id=ti_key.run_id,
            dag_id=ti_key.dag_id,
            task_id=ti_key.task_id,
            map_index=ti_key.map_index,
        )
        if stored_link:
            # Stripping is a temporary workaround till https://github.com/apache/airflow/issues/46513 is handled.
            return stored_link.strip('"')

        self.log.debug(
            "No link with name: %s present in XCom as key: %s, returning empty link",
            self.name,
            self.xcom_key,
        )
        return ""


@attrs.define()
class BaseOperatorLink(metaclass=ABCMeta):
"""Abstract base class that defines how we get an operator link."""

Expand All @@ -44,6 +83,17 @@ class BaseOperatorLink(metaclass=ABCMeta):
def name(self) -> str:
"""Name of the link. This will be the button name on the task UI."""

@property
def xcom_key(self) -> str:
    """
    XCom key under which the whole "link" for this operator link is stored.

    Retrieving the XCom value stored under this key returns the entire link.

    Defaults to ``_link_<class name>`` if not overridden.
    """
    class_name = self.__class__.__name__
    return "_link_" + class_name

@abstractmethod
def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey) -> str:
"""
Expand Down
137 changes: 77 additions & 60 deletions airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,14 @@
from airflow.callbacks.callback_requests import DagCallbackRequest, TaskCallbackRequest
from airflow.exceptions import AirflowException, SerializationError, TaskDeferred
from airflow.models.baseoperator import BaseOperator
from airflow.models.baseoperatorlink import BaseOperatorLink, XComOperatorLink
from airflow.models.connection import Connection
from airflow.models.dag import DAG, _get_model_data_interval
from airflow.models.expandinput import (
EXPAND_INPUT_EMPTY,
create_expand_input,
)
from airflow.models.taskinstance import SimpleTaskInstance
from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.models.xcom_arg import SchedulerXComArg, deserialize_xcom_arg
from airflow.providers_manager import ProvidersManager
Expand Down Expand Up @@ -96,7 +97,6 @@
from inspect import Parameter

from airflow.models import DagRun
from airflow.models.baseoperatorlink import BaseOperatorLink
from airflow.models.expandinput import ExpandInput
from airflow.sdk.definitions._internal.node import DAGNode
from airflow.sdk.types import Operator
Expand Down Expand Up @@ -1166,6 +1166,58 @@ def __init__(self, *args, **kwargs):
self.template_fields = BaseOperator.template_fields
self.operator_extra_links = BaseOperator.operator_extra_links

@cached_property
def operator_extra_link_dict(self) -> dict[str, BaseOperatorLink]:
    """
    Return every extra link that applies to this operator, keyed by link name.

    The operator's own ``operator_extra_links`` form the base mapping; links
    registered through plugins for this operator's class are applied on top
    and replace same-named entries.

    :raise AirflowException: If ``plugins_manager.operator_extra_links`` is still
        ``None`` after the plugins have been initialized.
    """
    from airflow import plugins_manager

    plugins_manager.initialize_extra_operators_links_plugins()
    if plugins_manager.operator_extra_links is None:
        raise AirflowException("Can't load operators")

    plugin_overrides: dict[str, Any] = {}
    for candidate in plugins_manager.operator_extra_links:
        # A plugin link without an ``operators`` list targets no operator at all.
        if not candidate.operators:
            continue
        if self.operator_class in candidate.operators:
            plugin_overrides[candidate.name] = candidate

    merged_links = {}
    for declared in self.operator_extra_links:
        merged_links[declared.name] = declared

    # Applied last: plugin-defined links take precedence over operator-defined ones.
    merged_links.update(plugin_overrides)
    return merged_links

@cached_property
def global_operator_extra_link_dict(self) -> dict[str, Any]:
    """
    Return all global extra links, keyed by link name.

    :raise AirflowException: If ``plugins_manager.global_operator_extra_links`` is
        still ``None`` after the plugins have been initialized.
    """
    from airflow import plugins_manager

    plugins_manager.initialize_extra_operators_links_plugins()
    if plugins_manager.global_operator_extra_links is None:
        raise AirflowException("Can't load operators")

    by_name: dict[str, Any] = {}
    for registered in plugins_manager.global_operator_extra_links:
        by_name[registered.name] = registered
    return by_name

@cached_property
def extra_links(self) -> list[str]:
    """Return the sorted names of every extra link, operator-level and global, without duplicates."""
    names = set(self.operator_extra_link_dict)
    names.update(self.global_operator_extra_link_dict)
    return sorted(names)

def get_extra_links(self, ti: TaskInstance, link_name: str) -> str | None:
    """
    Resolve the URL that one of this operator's ``extra_links`` entries points to.

    Operator-level links are searched first; global links are the fallback.

    :meta private:

    :raise ValueError: The error message of a ValueError will be passed on through to
        the frontend to show up as a tooltip on the disabled link.
    :param ti: The TaskInstance for the URL being searched for.
    :param link_name: The name of the link we're looking for the URL for. Should be
        one of the options specified in ``extra_links``.
    :return: The URL returned by the link's ``get_link``, or ``None`` when no link
        with that name exists.
    """
    # ``or`` short-circuits, so the global dict is only touched on a miss.
    matched_link = self.operator_extra_link_dict.get(link_name) or self.global_operator_extra_link_dict.get(
        link_name
    )
    if matched_link:
        return matched_link.get_link(self.unmap(None), ti_key=ti.key)
    return None

@property
def task_type(self) -> str:
# Overwrites task_type of BaseOperator to use _task_type instead of
Expand Down Expand Up @@ -1503,7 +1555,9 @@ def _is_excluded(cls, var: Any, attrname: str, op: DAGNode):
return super()._is_excluded(var, attrname, op)

@classmethod
def _deserialize_operator_extra_links(cls, encoded_op_links: list) -> dict[str, BaseOperatorLink]:
def _deserialize_operator_extra_links(
cls, encoded_op_links: dict[str, str]
) -> dict[str, XComOperatorLink]:
"""
Deserialize Operator Links if the Classes are registered in Airflow Plugins.

Expand All @@ -1520,77 +1574,40 @@ def _deserialize_operator_extra_links(cls, encoded_op_links: list) -> dict[str,
raise AirflowException("Can't load plugins")
op_predefined_extra_links = {}

for _operator_links_source in encoded_op_links:
# Get the key, value pair as Tuple where key is OperatorLink ClassName
# and value is the dictionary containing the arguments passed to the OperatorLink
#
# Example of a single iteration:
#
# _operator_links_source =
# {
# 'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleIndexableLink': {
# 'index': 0
# }
# },
#
# list(_operator_links_source.items()) =
# [
# (
# 'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleIndexableLink',
# {'index': 0}
# )
# ]
for name, xcom_key in encoded_op_links.items():
# Get the name and xcom_key of the encoded operator and use it to create a XComOperatorLink object
# during deserialization.
#
# list(_operator_links_source.items())[0] =
# (
# 'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleIndexableLink',
# {
# 'index': 0
# }
# )

_operator_link_class_path, data = next(iter(_operator_links_source.items()))
if _operator_link_class_path in get_operator_extra_links():
single_op_link_class = import_string(_operator_link_class_path)
elif _operator_link_class_path in plugins_manager.registered_operator_link_classes:
single_op_link_class = plugins_manager.registered_operator_link_classes[
_operator_link_class_path
]
else:
log.error("Operator Link class %r not registered", _operator_link_class_path)
return {}

op_link_parameters = {param: cls.deserialize(value) for param, value in data.items()}
op_predefined_extra_link: BaseOperatorLink = single_op_link_class(**op_link_parameters)

# Example:
# enc_operator['_operator_extra_links'] =
# {
# 'airflow': 'airflow_link_key',
# 'foo-bar': 'link-key',
# 'no_response': 'key',
# 'raise_error': 'key'
# }

op_predefined_extra_link = XComOperatorLink(name=name, xcom_key=xcom_key)
op_predefined_extra_links.update({op_predefined_extra_link.name: op_predefined_extra_link})

return op_predefined_extra_links

@classmethod
def _serialize_operator_extra_links(cls, operator_extra_links: Iterable[BaseOperatorLink]):
def _serialize_operator_extra_links(
cls, operator_extra_links: Iterable[BaseOperatorLink]
) -> dict[str, str]:
"""
Serialize Operator Links.

Store the import path of the OperatorLink and the arguments passed to it.
Store the "name" of the link mapped with the xcom_key which can be later used to retrieve this
operator extra link from XComs.
For example:
``[{'airflow.providers.google.cloud.links.bigquery.BigQueryDatasetLink': {}}]``
``{'link-name-1': 'xcom-key-1'}``

:param operator_extra_links: Operator Link
:return: Serialized Operator Link
"""
serialize_operator_extra_links = []
for operator_extra_link in operator_extra_links:
op_link_arguments = {
param: cls.serialize(value) for param, value in attrs.asdict(operator_extra_link).items()
}

module_path = (
f"{operator_extra_link.__class__.__module__}.{operator_extra_link.__class__.__name__}"
)
serialize_operator_extra_links.append({module_path: op_link_arguments})

return serialize_operator_extra_links
return {link.name: link.xcom_key for link in operator_extra_links}

@classmethod
def serialize(cls, var: Any, *, strict: bool = False) -> Any:
Expand Down
1 change: 1 addition & 0 deletions newsfragments/46613.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Operator Links interface changed to not run user code in Airflow Webserver

The Operator Extra links, which can be defined either via plugins or custom operators, no longer execute any user code in the Airflow Webserver. Instead, the "full" links are pushed to the XCom backend, and the value is fetched from the XCom backend again when viewing task details in grid view.
Original file line number Diff line number Diff line change
Expand Up @@ -194,14 +194,7 @@ def assert_extra_link_url(
)

error_msg = f"{self.full_qualname!r} should be preserved after execution"
assert ti.task.get_extra_links(ti, self.link_class.name) == expected_url, error_msg

serialized_dag = self.dag_maker.get_serialized_data()
deserialized_dag = SerializedDAG.from_dict(serialized_dag)
deserialized_task = deserialized_dag.task_dict[self.task_id]

error_msg = f"{self.full_qualname!r} should be preserved in deserialized tasks after execution"
assert deserialized_task.get_extra_links(ti, self.link_class.name) == expected_url, error_msg
assert task.operator_extra_links[0].get_link(operator=task, ti_key=ti.key) == expected_url, error_msg

def test_link_serialize(self):
"""Test: Operator links should exist for serialized DAG."""
Expand All @@ -223,7 +216,7 @@ def test_empty_xcom(self):
deserialized_task = deserialized_dag.task_dict[self.task_id]

assert (
ti.task.get_extra_links(ti, self.link_class.name) == ""
ti.task.operator_extra_links[0].get_link(operator=ti.task, ti_key=ti.key) == ""
), "Operator link should only be added if job id is available in XCom"

assert (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -658,7 +658,7 @@ def test_run_job_operator_link(self, conn_id, account_id, create_task_instance_o

ti.xcom_push(key="job_run_url", value=_run_response["data"]["href"])

url = ti.task.get_extra_links(ti, "Monitor Job Run")
url = ti.task.operator_extra_links[0].get_link(operator=ti.task, ti_key=ti.key)

assert url == (
EXPECTED_JOB_RUN_OP_EXTRA_LINK.format(
Expand Down
Loading
Loading