Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 21 additions & 7 deletions airflow/models/mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,9 @@
ValidationSource = Union[Literal["expand"], Literal["partial"]]


def validate_mapping_kwargs(op: type[BaseOperator], func: ValidationSource, value: dict[str, Any]) -> None:
def extract_unknown_args(
op: type[BaseOperator], func: ValidationSource, value: dict[str, Any]
) -> dict[str, Any]:
# use a dict so order of args is same as code order
unknown_args = value.copy()
for klass in op.mro():
Expand All @@ -88,21 +90,29 @@ def validate_mapping_kwargs(op: type[BaseOperator], func: ValidationSource, valu
param_names = init._BaseOperatorMeta__param_names
except AttributeError:
continue

for name in param_names:
value = unknown_args.pop(name, NOTSET)
unknown_arg = unknown_args.pop(name, NOTSET)
if func != "expand":
continue
if value is NOTSET:
if unknown_arg is NOTSET:
continue
if is_mappable(value):
if is_mappable(unknown_arg):
continue
type_name = type(value).__name__
type_name = type(unknown_arg).__name__
error = f"{op.__name__}.expand() got an unexpected type {type_name!r} for keyword argument {name}"
raise ValueError(error)
if not unknown_args:
return # If we have no args left to check: stop looking at the MRO chain.
return {} # If we have no args left to check: stop looking at the MRO chain.
return unknown_args

if len(unknown_args) == 1:

def validate_mapping_kwargs(op: type[BaseOperator], func: ValidationSource, value: dict[str, Any]) -> None:
unknown_args = extract_unknown_args(op, func, value)

if not unknown_args:
return
elif len(unknown_args) == 1:
error = f"an unexpected keyword argument {unknown_args.popitem()[0]!r}"
else:
names = ", ".join(repr(n) for n in unknown_args)
Expand Down Expand Up @@ -145,6 +155,10 @@ class OperatorPartial:
def __attrs_post_init__(self):
from airflow.operators.subdag import SubDagOperator

unknown_args = extract_unknown_args(self.operator_class, "partial", self.kwargs)
for arg in unknown_args.keys():
del self.kwargs[arg]

if issubclass(self.operator_class, SubDagOperator):
raise TypeError("Mapping over deprecated SubDagOperator is not supported")
validate_mapping_kwargs(self.operator_class, "partial", self.kwargs)
Expand Down