From c2d0eeced785f03eb54d1b235f469f7b247ed161 Mon Sep 17 00:00:00 2001 From: Victor Chiapaikeo Date: Sun, 12 Mar 2023 15:08:48 -0400 Subject: [PATCH] Support default_args in dynamic task mapping --- airflow/models/mappedoperator.py | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py index 2cbd46cd57384..1c5b241108c73 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -79,7 +79,9 @@ ValidationSource = Union[Literal["expand"], Literal["partial"]] -def validate_mapping_kwargs(op: type[BaseOperator], func: ValidationSource, value: dict[str, Any]) -> None: +def extract_unknown_args( + op: type[BaseOperator], func: ValidationSource, value: dict[str, Any] +) -> dict[str, Any]: # use a dict so order of args is same as code order unknown_args = value.copy() for klass in op.mro(): @@ -88,21 +90,29 @@ def validate_mapping_kwargs(op: type[BaseOperator], func: ValidationSource, valu param_names = init._BaseOperatorMeta__param_names except AttributeError: continue + for name in param_names: - value = unknown_args.pop(name, NOTSET) + unknown_arg = unknown_args.pop(name, NOTSET) if func != "expand": continue - if value is NOTSET: + if unknown_arg is NOTSET: continue - if is_mappable(value): + if is_mappable(unknown_arg): continue - type_name = type(value).__name__ + type_name = type(unknown_arg).__name__ error = f"{op.__name__}.expand() got an unexpected type {type_name!r} for keyword argument {name}" raise ValueError(error) if not unknown_args: - return # If we have no args left to check: stop looking at the MRO chain. + return {} # If we have no args left to check: stop looking at the MRO chain. + return unknown_args - if len(unknown_args) == 1: + +def validate_mapping_kwargs(op: type[BaseOperator], func: ValidationSource, value: dict[str, Any]) -> None: + unknown_args = extract_unknown_args(op, func, value) + + if not unknown_args: + return + elif len(unknown_args) == 1: error = f"an unexpected keyword argument {unknown_args.popitem()[0]!r}" else: names = ", ".join(repr(n) for n in unknown_args) @@ -145,6 +155,10 @@ class OperatorPartial: def __attrs_post_init__(self): from airflow.operators.subdag import SubDagOperator + unknown_args = extract_unknown_args(self.operator_class, "partial", self.kwargs) + for arg in unknown_args.keys(): + del self.kwargs[arg] + if issubclass(self.operator_class, SubDagOperator): raise TypeError("Mapping over deprecated SubDagOperator is not supported") validate_mapping_kwargs(self.operator_class, "partial", self.kwargs)