From abb792268cd0cef624404b8ebe26f1746a95ec2a Mon Sep 17 00:00:00 2001 From: Pavan Kumar Date: Sat, 14 Dec 2024 21:48:17 +0000 Subject: [PATCH 1/3] Fix task_id validation in baseoperator --- airflow/models/baseoperator.py | 3 ++- tests/models/test_baseoperator.py | 8 ++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 773552184f103..3c86226c09ce5 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -965,12 +965,13 @@ def __init__( category=RemovedInAirflow3Warning, stacklevel=3, ) - validate_key(task_id) + dag = dag or DagContext.get_current_dag() task_group = task_group or TaskGroupContext.get_current_task_group(dag) self.task_id = task_group.child_id(task_id) if task_group else task_id + validate_key(self.task_id) if not self.__from_mapped and task_group: task_group.add(self) diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py index 48aaf2699b918..65da48d046a10 100644 --- a/tests/models/test_baseoperator.py +++ b/tests/models/test_baseoperator.py @@ -564,6 +564,14 @@ def test_chain(self): assert [op2] == tgop3.get_direct_relatives(upstream=False) assert [op2] == tgop4.get_direct_relatives(upstream=False) + def test_baseoperator_raises_exception_when_task_id_invalid(self): + """Test exception is raised when operator task id + taskgroup id > 250 chars.""" + dag = DAG(dag_id="foo", schedule=None, start_date=datetime.now()) + + with pytest.raises(AirflowException, match="The key has to be less than 250 characters"): + tg1 = TaskGroup("A" * 20, dag=dag) + BaseOperator(task_id="1" * 250, task_group=tg1, dag=dag) + def test_chain_linear(self): dag = DAG(dag_id="test_chain_linear", schedule=None, start_date=datetime.now()) From 1a5946df0a2365371c5fe1c297760d8e3c94b61a Mon Sep 17 00:00:00 2001 From: Pavan Kumar Date: Sat, 14 Dec 2024 21:55:13 +0000 Subject: [PATCH 2/3] Fix task_id validation in baseoperator --- airflow/models/baseoperator.py | 3 ++- tests/models/test_baseoperator.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 3c86226c09ce5..65900276271a6 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -966,12 +966,13 @@ def __init__( stacklevel=3, ) - dag = dag or DagContext.get_current_dag() task_group = task_group or TaskGroupContext.get_current_task_group(dag) self.task_id = task_group.child_id(task_id) if task_group else task_id + validate_key(self.task_id) + if not self.__from_mapped and task_group: task_group.add(self) diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py index 65da48d046a10..c4aeec08b7d23 100644 --- a/tests/models/test_baseoperator.py +++ b/tests/models/test_baseoperator.py @@ -568,8 +568,8 @@ def test_baseoperator_raises_exception_when_task_id_invalid(self): """Test exception is raised when operator task id + taskgroup id > 250 chars.""" dag = DAG(dag_id="foo", schedule=None, start_date=datetime.now()) + tg1 = TaskGroup("A" * 20, dag=dag) with pytest.raises(AirflowException, match="The key has to be less than 250 characters"): - tg1 = TaskGroup("A" * 20, dag=dag) BaseOperator(task_id="1" * 250, task_group=tg1, dag=dag) def test_chain_linear(self): From 94b103edc0b0e3f00624a5fb97ddc0cbf2158a49 Mon Sep 17 00:00:00 2001 From: Pavan Kumar Date: Sun, 15 Dec 2024 00:32:31 +0000 Subject: [PATCH 3/3] add additional tests to check task id length --- tests/models/test_baseoperator.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py index c4aeec08b7d23..8ce9ca195e913 100644 --- a/tests/models/test_baseoperator.py +++ b/tests/models/test_baseoperator.py @@ -564,7 +564,7 @@ def test_chain(self): assert [op2] == tgop3.get_direct_relatives(upstream=False) assert [op2] == tgop4.get_direct_relatives(upstream=False) - def test_baseoperator_raises_exception_when_task_id_invalid(self): + def test_baseoperator_raises_exception_when_task_id_plus_taskgroup_id_exceeds_250_chars(self): """Test exception is raised when operator task id + taskgroup id > 250 chars.""" dag = DAG(dag_id="foo", schedule=None, start_date=datetime.now()) @@ -572,6 +572,25 @@ def test_baseoperator_raises_exception_when_task_id_invalid(self): with pytest.raises(AirflowException, match="The key has to be less than 250 characters"): BaseOperator(task_id="1" * 250, task_group=tg1, dag=dag) + def test_baseoperator_with_task_id_and_taskgroup_id_less_than_250_chars(self): + """Test exception is not raised when operator task id + taskgroup id < 250 chars.""" + dag = DAG(dag_id="foo", schedule=None, start_date=datetime.now()) + + tg1 = TaskGroup("A" * 10, dag=dag) + try: + BaseOperator(task_id="1" * 239, task_group=tg1, dag=dag) + except Exception as e: + pytest.fail(f"Exception raised: {e}") + + def test_baseoperator_with_task_id_less_than_250_chars(self): + """Test exception is not raised when operator task id < 250 chars.""" + dag = DAG(dag_id="foo", schedule=None, start_date=datetime.now()) + + try: + BaseOperator(task_id="1" * 249, dag=dag) + except Exception as e: + pytest.fail(f"Exception raised: {e}") + def test_chain_linear(self): dag = DAG(dag_id="test_chain_linear", schedule=None, start_date=datetime.now())