From 997c73ebd19195830857e063ebe82241b4cdd86e Mon Sep 17 00:00:00 2001 From: Pavan Kumar Date: Sat, 14 Dec 2024 22:14:13 +0000 Subject: [PATCH 1/3] Fix task id validation in BaseOperator --- task_sdk/src/airflow/sdk/definitions/baseoperator.py | 2 +- tests/models/test_baseoperator.py | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/task_sdk/src/airflow/sdk/definitions/baseoperator.py b/task_sdk/src/airflow/sdk/definitions/baseoperator.py index 9b1f64d970e64..aaaee72a25df1 100644 --- a/task_sdk/src/airflow/sdk/definitions/baseoperator.py +++ b/task_sdk/src/airflow/sdk/definitions/baseoperator.py @@ -732,7 +732,7 @@ def __init__( f"Invalid arguments were passed to {self.__class__.__name__} (task_id: {task_id}). " f"Invalid arguments were:\n**kwargs: {kwargs}", ) - validate_key(task_id) + validate_key(self.task_id) self.owner = owner self.email = email diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py index e95866d95a5e9..076fa17535959 100644 --- a/tests/models/test_baseoperator.py +++ b/tests/models/test_baseoperator.py @@ -395,6 +395,14 @@ def test_chain(self): assert [op2] == tgop3.get_direct_relatives(upstream=False) assert [op2] == tgop4.get_direct_relatives(upstream=False) + def test_baseoperator_raises_exception_when_task_id_invalid(self): + """Test exception is raised when operator task id + taskgroup id > 250 chars.""" + dag = DAG(dag_id="foo", schedule=None, start_date=datetime.now()) + + tg1 = TaskGroup("A" * 20, dag=dag) + with pytest.raises(ValueError, match="The key has to be less than 250 characters"): + BaseOperator(task_id="1" * 250, task_group=tg1, dag=dag) + def test_chain_linear(self): dag = DAG(dag_id="test_chain_linear", schedule=None, start_date=datetime.now()) From 14b0ff0f056bfaa161a01989e091ef28504f424c Mon Sep 17 00:00:00 2001 From: Pavan Kumar Date: Sun, 15 Dec 2024 00:34:15 +0000 Subject: [PATCH 2/3] add additional tests to check task id length --- tests/models/test_baseoperator.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py index 076fa17535959..ad6b13d6a814b 100644 --- a/tests/models/test_baseoperator.py +++ b/tests/models/test_baseoperator.py @@ -395,14 +395,33 @@ def test_chain(self): assert [op2] == tgop3.get_direct_relatives(upstream=False) assert [op2] == tgop4.get_direct_relatives(upstream=False) - def test_baseoperator_raises_exception_when_task_id_invalid(self): + def test_baseoperator_raises_exception_when_task_id_plus_taskgroup_id_exceeds_250_chars(self): """Test exception is raised when operator task id + taskgroup id > 250 chars.""" dag = DAG(dag_id="foo", schedule=None, start_date=datetime.now()) tg1 = TaskGroup("A" * 20, dag=dag) - with pytest.raises(ValueError, match="The key has to be less than 250 characters"): + with pytest.raises(AirflowException, match="The key has to be less than 250 characters"): BaseOperator(task_id="1" * 250, task_group=tg1, dag=dag) + def test_baseoperator_with_task_id_and_taskgroup_id_less_than_250_chars(self): + """Test exception is not raised when operator task id + taskgroup id < 250 chars.""" + dag = DAG(dag_id="foo", schedule=None, start_date=datetime.now()) + + tg1 = TaskGroup("A" * 10, dag=dag) + try: + BaseOperator(task_id="1" * 239, task_group=tg1, dag=dag) + except Exception as e: + pytest.fail(f"Exception raised: {e}") + + def test_baseoperator_with_task_id_less_than_250_chars(self): + """Test exception is not raised when operator task id < 250 chars.""" + dag = DAG(dag_id="foo", schedule=None, start_date=datetime.now()) + + try: + BaseOperator(task_id="1" * 249, dag=dag) + except Exception as e: + pytest.fail(f"Exception raised: {e}") + def test_chain_linear(self): dag = DAG(dag_id="test_chain_linear", schedule=None, start_date=datetime.now()) From 89b0082f2eb7a7f11f50a55d3b9763128ff9b860 Mon Sep 17 00:00:00 2001 From: Pavan Kumar Date: Sun, 15 Dec 2024 01:24:57 +0000 Subject: [PATCH 3/3] fix assert statement --- tests/models/test_baseoperator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py index ad6b13d6a814b..2c598edc777ac 100644 --- a/tests/models/test_baseoperator.py +++ b/tests/models/test_baseoperator.py @@ -400,7 +400,7 @@ def test_baseoperator_raises_exception_when_task_id_plus_taskgroup_id_exceeds_25 dag = DAG(dag_id="foo", schedule=None, start_date=datetime.now()) tg1 = TaskGroup("A" * 20, dag=dag) - with pytest.raises(AirflowException, match="The key has to be less than 250 characters"): + with pytest.raises(ValueError, match="The key has to be less than 250 characters"): BaseOperator(task_id="1" * 250, task_group=tg1, dag=dag) def test_baseoperator_with_task_id_and_taskgroup_id_less_than_250_chars(self):