From bda5627bd45c877ea0a18b5a3d6ac76ac325ba41 Mon Sep 17 00:00:00 2001 From: phi Date: Wed, 28 Aug 2024 21:33:57 +0900 Subject: [PATCH 01/12] fix: add skip_if, run_if --- airflow/utils/decorators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/utils/decorators.py b/airflow/utils/decorators.py index 4cad5ab9e6073..e299999423e56 100644 --- a/airflow/utils/decorators.py +++ b/airflow/utils/decorators.py @@ -49,7 +49,7 @@ def _remove_task_decorator(py_source, decorator_name): after_decorator = after_decorator[1:] return before_decorator + after_decorator - decorators = ["@setup", "@teardown", task_decorator_name] + decorators = ["@setup", "@teardown", "@task.skip_if", "@task.run_if", task_decorator_name] for decorator in decorators: python_source = _remove_task_decorator(python_source, decorator) return python_source From dc8efb4b8cec625a2da1dd713bec3afd07705bee Mon Sep 17 00:00:00 2001 From: phi Date: Wed, 28 Aug 2024 21:37:52 +0900 Subject: [PATCH 02/12] tests: skip_if, run_if --- tests/utils/test_decorators.py | 41 ++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100644 tests/utils/test_decorators.py diff --git a/tests/utils/test_decorators.py b/tests/utils/test_decorators.py new file mode 100644 index 0000000000000..7a6857c4d8677 --- /dev/null +++ b/tests/utils/test_decorators.py @@ -0,0 +1,41 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from airflow.decorators import task + + +def test_remove_skip_if_decorator(): + @task.skip_if(lambda context: True) + @task.virtualenv() + def f(): ... + + xcom_arg = f() + source = xcom_arg.operator.get_python_source() + + assert "skip_if" not in source + + +def test_remove_run_if_decorator(): + @task.run_if(lambda context: True) + @task.virtualenv() + def f(): ... + + xcom_arg = f() + source = xcom_arg.operator.get_python_source() + + assert "run_if" not in source From b1aaa66c233d7ec6583e22dfefa112ab571afdf1 Mon Sep 17 00:00:00 2001 From: phi Date: Sun, 1 Sep 2024 21:47:24 +0900 Subject: [PATCH 03/12] test --- tests/utils/test_decorators.py | 53 +++++++++++++++++++++++++--------- 1 file changed, 39 insertions(+), 14 deletions(-) diff --git a/tests/utils/test_decorators.py b/tests/utils/test_decorators.py index 7a6857c4d8677..97ceef3ce2e29 100644 --- a/tests/utils/test_decorators.py +++ b/tests/utils/test_decorators.py @@ -16,26 +16,51 @@ # under the License. from __future__ import annotations +from typing import TYPE_CHECKING + from airflow.decorators import task +if TYPE_CHECKING: + from airflow.decorators.base import Task + + +class TestDecoratorSource: + def parse_python_source(self, task: Task) -> str: + return task().operator.get_python_source() + + def test_branch_external_python(self): + @task.branch_virtualenv() + def f(): + return ["some_task"] + + assert self.parse_python_source(f) == 'def f():\n return ["some_task"]\n' + + def test_branch_virtualenv(self): + @task.external_python(python="python3") + def f(): + return "hello world" -def test_remove_skip_if_decorator(): - @task.skip_if(lambda context: True) - @task.virtualenv() - def f(): ... + assert self.parse_python_source(f) == 'def f():\n return "hello world"\n' - xcom_arg = f() - source = xcom_arg.operator.get_python_source() + def test_virtualenv(self): + @task.virtualenv() + def f(): + return "hello world" - assert "skip_if" not in source + assert self.parse_python_source(f) == 'def f():\n return "hello world"\n' + def test_skip_if(self): + @task.skip_if(lambda context: True) + @task.virtualenv() + def f(): + return "hello world" -def test_remove_run_if_decorator(): - @task.run_if(lambda context: True) - @task.virtualenv() - def f(): ... + assert self.parse_python_source(f) == 'def f():\n return "hello world"\n' - xcom_arg = f() - source = xcom_arg.operator.get_python_source() + def test_run_if(self): + @task.run_if(lambda context: True) + @task.virtualenv() + def f(): + return "hello world" - assert "run_if" not in source + assert self.parse_python_source(f) == 'def f():\n return "hello world"\n' From f1103959c46719f04e1990437c93d9daaa1febc2 Mon Sep 17 00:00:00 2001 From: phi Date: Sun, 1 Sep 2024 22:36:44 +0900 Subject: [PATCH 04/12] more test --- tests/utils/test_decorators.py | 101 +++++++++++++++++++++++++++++++-- 1 file changed, 97 insertions(+), 4 deletions(-) diff --git a/tests/utils/test_decorators.py b/tests/utils/test_decorators.py index 97ceef3ce2e29..9f8bb83b0481f 100644 --- a/tests/utils/test_decorators.py +++ b/tests/utils/test_decorators.py @@ -16,23 +16,52 @@ # under the License. from __future__ import annotations +import ast from typing import TYPE_CHECKING +import pytest + from airflow.decorators import task if TYPE_CHECKING: from airflow.decorators.base import Task +DECORATORS = set(x for x in dir(task) if not x.startswith("_")) - {"skip_if", "run_if"} + class TestDecoratorSource: - def parse_python_source(self, task: Task) -> str: - return task().operator.get_python_source() + @staticmethod + def parse_python_source(task: Task) -> str: + operator = task().operator + public_methods = {x for x in dir(operator) if not x.startswith("_")} + if "get_python_source" not in public_methods: + pytest.skip(f"Operator {operator} does not have get_python_source method") + return operator.get_python_source() + + @staticmethod + def init_decorator(decorator_name: str): + decorator_factory = getattr(task, decorator_name) + + kwargs = {} + if "external" in decorator_name: + kwargs["python"] = "python3" + return decorator_factory(**kwargs) + + @classmethod + def parse_decorator_names(cls, source: Task | str) -> list[str]: + if not isinstance(source, str): + source = cls.parse_python_source(source) + node = ast.parse(source) + func: ast.FunctionDef = node.body[0] + decorators: list[ast.Name] = func.decorator_list + return [decorator.id for decorator in decorators] def test_branch_external_python(self): @task.branch_virtualenv() def f(): return ["some_task"] + assert not self.parse_decorator_names(f) assert self.parse_python_source(f) == 'def f():\n return ["some_task"]\n' def test_branch_virtualenv(self): @@ -40,6 +69,7 @@ def test_branch_virtualenv(self): def f(): return "hello world" + assert not self.parse_decorator_names(f) assert self.parse_python_source(f) == 'def f():\n return "hello world"\n' def test_virtualenv(self): @@ -47,20 +77,83 @@ def test_virtualenv(self): def f(): return "hello world" + assert not self.parse_decorator_names(f) assert self.parse_python_source(f) == 'def f():\n return "hello world"\n' - def test_skip_if(self): + @pytest.mark.parametrize("decorator_name", DECORATORS) + def test_skip_if(self, decorator_name): + decorator = self.init_decorator(decorator_name) + @task.skip_if(lambda context: True) + @decorator + def f(): + return "hello world" + + source = self.parse_python_source(f) + decorators = self.parse_decorator_names(source) + + # In `airflow.utils.decorators.remove_task_decorator`, `@decorator` should be removed, + # but it does so using `custom_operator_name`, which is defined as a string, + # we have to check it ourselves during testing. + assert len(decorators) == 1 + assert decorators[0] == "decorator" + + @pytest.mark.parametrize("decorator_name", DECORATORS) + def test_run_if(self, decorator_name): + decorator = self.init_decorator(decorator_name) + + @task.run_if(lambda context: True) + @decorator + def f(): + return "hello world" + + source = self.parse_python_source(f) + + # In `airflow.utils.decorators.remove_task_decorator`, `@decorator` should be removed, + # but it does so using `custom_operator_name`, which is defined as a string, + # we have to check it ourselves during testing. + decorators = self.parse_decorator_names(source) + assert len(decorators) == 1 + assert decorators[0] == "decorator" + + def test_skip_if_and_run_if(self): + @task.skip_if(lambda context: True) + @task.run_if(lambda context: True) @task.virtualenv() def f(): return "hello world" assert self.parse_python_source(f) == 'def f():\n return "hello world"\n' - def test_run_if(self): + def test_run_if_and_skip_if(self): @task.run_if(lambda context: True) + @task.skip_if(lambda context: True) @task.virtualenv() def f(): return "hello world" assert self.parse_python_source(f) == 'def f():\n return "hello world"\n' + + def test_skip_if_allow_decorator(self): + def decorator(func): + return func + + @task.skip_if(lambda context: True) + @task.virtualenv() + @decorator + def f(): + return "hello world" + + assert self.parse_decorator_names(f) == ["decorator"] + + def test_run_if_allow_decorator(self): + def decorator(func): + return func + + @task.run_if(lambda context: True) + @task.virtualenv() + @decorator + def f(): + return "hello world" + + assert self.parse_decorator_names(f) == ["decorator"] From ffa04166dd43288c77b234c627ba1f8273faf0d0 Mon Sep 17 00:00:00 2001 From: phi Date: Sun, 1 Sep 2024 23:04:16 +0900 Subject: [PATCH 05/12] fix: mypy errors --- tests/utils/test_decorators.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/utils/test_decorators.py b/tests/utils/test_decorators.py index 9f8bb83b0481f..41c8bf42d978b 100644 --- a/tests/utils/test_decorators.py +++ b/tests/utils/test_decorators.py @@ -52,8 +52,8 @@ def parse_decorator_names(cls, source: Task | str) -> list[str]: if not isinstance(source, str): source = cls.parse_python_source(source) node = ast.parse(source) - func: ast.FunctionDef = node.body[0] - decorators: list[ast.Name] = func.decorator_list + func: ast.FunctionDef = node.body[0] # type: ignore[assignment] + decorators: list[ast.Name] = func.decorator_list # type: ignore[assignment] return [decorator.id for decorator in decorators] def test_branch_external_python(self): From cb5bbcaa4e668f5a89b1203f9f7647c3afe37347 Mon Sep 17 00:00:00 2001 From: phi Date: Sun, 1 Sep 2024 23:27:44 +0900 Subject: [PATCH 06/12] tests --- tests/utils/test_decorators.py | 45 +++++++++++++++++----------------- 1 file changed, 22 insertions(+), 23 deletions(-) diff --git a/tests/utils/test_decorators.py b/tests/utils/test_decorators.py index 41c8bf42d978b..a0385742742e9 100644 --- a/tests/utils/test_decorators.py +++ b/tests/utils/test_decorators.py @@ -17,7 +17,7 @@ from __future__ import annotations import ast -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import pytest @@ -31,12 +31,11 @@ class TestDecoratorSource: @staticmethod - def parse_python_source(task: Task) -> str: - operator = task().operator - public_methods = {x for x in dir(operator) if not x.startswith("_")} - if "get_python_source" not in public_methods: - pytest.skip(f"Operator {operator} does not have get_python_source method") - return operator.get_python_source() + def update_custom_operator_name(operator: Any, custom_operator_name: str): + custom_operator_name = ( + custom_operator_name if custom_operator_name.startswith("@") else f"@{custom_operator_name}" + ) + operator.__dict__["custom_operator_name"] = custom_operator_name @staticmethod def init_decorator(decorator_name: str): @@ -47,6 +46,16 @@ def init_decorator(decorator_name: str): kwargs["python"] = "python3" return decorator_factory(**kwargs) + @classmethod + def parse_python_source(cls, task: Task, custom_operator_name: str | None = None) -> str: + operator = task().operator + public_methods = {x for x in dir(operator) if not x.startswith("_")} + if "get_python_source" not in public_methods: + pytest.skip(f"Operator {operator} does not have get_python_source method") + if custom_operator_name: + cls.update_custom_operator_name(operator, custom_operator_name) + return operator.get_python_source() + @classmethod def parse_decorator_names(cls, source: Task | str) -> list[str]: if not isinstance(source, str): @@ -89,14 +98,8 @@ def test_skip_if(self, decorator_name): def f(): return "hello world" - source = self.parse_python_source(f) - decorators = self.parse_decorator_names(source) - - # In `airflow.utils.decorators.remove_task_decorator`, `@decorator` should be removed, - # but it does so using `custom_operator_name`, which is defined as a string, - # we have to check it ourselves during testing. - assert len(decorators) == 1 - assert decorators[0] == "decorator" + source = self.parse_python_source(f, "decorator") + assert source == 'def f():\n return "hello world"\n' @pytest.mark.parametrize("decorator_name", DECORATORS) def test_run_if(self, decorator_name): @@ -107,14 +110,8 @@ def test_run_if(self, decorator_name): def f(): return "hello world" - source = self.parse_python_source(f) - - # In `airflow.utils.decorators.remove_task_decorator`, `@decorator` should be removed, - # but it does so using `custom_operator_name`, which is defined as a string, - # we have to check it ourselves during testing. - decorators = self.parse_decorator_names(source) - assert len(decorators) == 1 - assert decorators[0] == "decorator" + source = self.parse_python_source(f, "decorator") + assert source == 'def f():\n return "hello world"\n' def test_skip_if_and_run_if(self): @task.skip_if(lambda context: True) @@ -145,6 +142,7 @@ def f(): return "hello world" assert self.parse_decorator_names(f) == ["decorator"] + assert self.parse_python_source(f) == '@decorator\ndef f():\n return "hello world"\n' def test_run_if_allow_decorator(self): def decorator(func): @@ -157,3 +155,4 @@ def f(): return "hello world" assert self.parse_decorator_names(f) == ["decorator"] + assert self.parse_python_source(f) == '@decorator\ndef f():\n return "hello world"\n' From 6c12870b46491c470a2ac34cae5bccdc58284c22 Mon Sep 17 00:00:00 2001 From: phi Date: Sun, 1 Sep 2024 23:56:19 +0900 Subject: [PATCH 07/12] fix: xdist --- tests/utils/test_decorators.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/utils/test_decorators.py b/tests/utils/test_decorators.py index a0385742742e9..1daa29dd17897 100644 --- a/tests/utils/test_decorators.py +++ b/tests/utils/test_decorators.py @@ -25,8 +25,7 @@ if TYPE_CHECKING: from airflow.decorators.base import Task - -DECORATORS = set(x for x in dir(task) if not x.startswith("_")) - {"skip_if", "run_if"} +DECORATORS = tuple(set(x for x in dir(task) if not x.startswith("_")) - {"skip_if", "run_if"}) class TestDecoratorSource: From 13c3f6d5a94a8771c68da2e420fffbac7a0523bf Mon Sep 17 00:00:00 2001 From: phi Date: Mon, 2 Sep 2024 00:10:36 +0900 Subject: [PATCH 08/12] fix: xdist --- tests/utils/test_decorators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utils/test_decorators.py b/tests/utils/test_decorators.py index 1daa29dd17897..13657a993d389 100644 --- a/tests/utils/test_decorators.py +++ b/tests/utils/test_decorators.py @@ -25,7 +25,7 @@ if TYPE_CHECKING: from airflow.decorators.base import Task -DECORATORS = tuple(set(x for x in dir(task) if not x.startswith("_")) - {"skip_if", "run_if"}) +DECORATORS = sorted(set(x for x in dir(task) if not x.startswith("_")) - {"skip_if", "run_if"}) class TestDecoratorSource: From 753202ba4ae33de2ca4cfad5d70c8a7cf7fc49d5 Mon Sep 17 00:00:00 2001 From: phi Date: Mon, 2 Sep 2024 00:29:15 +0900 Subject: [PATCH 09/12] fix: use hasattr --- tests/utils/test_decorators.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/utils/test_decorators.py b/tests/utils/test_decorators.py index 13657a993d389..4752885c40b3b 100644 --- a/tests/utils/test_decorators.py +++ b/tests/utils/test_decorators.py @@ -48,8 +48,7 @@ def init_decorator(decorator_name: str): @classmethod def parse_python_source(cls, task: Task, custom_operator_name: str | None = None) -> str: operator = task().operator - public_methods = {x for x in dir(operator) if not x.startswith("_")} - if "get_python_source" not in public_methods: + if not hasattr(operator, "get_python_source"): pytest.skip(f"Operator {operator} does not have get_python_source method") if custom_operator_name: cls.update_custom_operator_name(operator, custom_operator_name) From f21d8d602ec448d642ebd1dc0b2465d137127ab0 Mon Sep 17 00:00:00 2001 From: phi Date: Wed, 4 Sep 2024 20:42:38 +0900 Subject: [PATCH 10/12] fix: use fixture --- tests/utils/test_decorators.py | 229 +++++++++++++++------------------ 1 file changed, 101 insertions(+), 128 deletions(-) diff --git a/tests/utils/test_decorators.py b/tests/utils/test_decorators.py index 4752885c40b3b..6e37bd6ce0938 100644 --- a/tests/utils/test_decorators.py +++ b/tests/utils/test_decorators.py @@ -16,7 +16,6 @@ # under the License. from __future__ import annotations -import ast from typing import TYPE_CHECKING, Any import pytest @@ -24,133 +23,107 @@ from airflow.decorators import task if TYPE_CHECKING: - from airflow.decorators.base import Task + from airflow.decorators.base import Task, TaskDecorator DECORATORS = sorted(set(x for x in dir(task) if not x.startswith("_")) - {"skip_if", "run_if"}) +DECORATORS_USING_SOURCE = ("external_python", "virtualenv", "branch_virtualenv", "branch_external_python") -class TestDecoratorSource: - @staticmethod - def update_custom_operator_name(operator: Any, custom_operator_name: str): - custom_operator_name = ( - custom_operator_name if custom_operator_name.startswith("@") else f"@{custom_operator_name}" - ) - operator.__dict__["custom_operator_name"] = custom_operator_name - - @staticmethod - def init_decorator(decorator_name: str): - decorator_factory = getattr(task, decorator_name) - - kwargs = {} - if "external" in decorator_name: - kwargs["python"] = "python3" - return decorator_factory(**kwargs) - - @classmethod - def parse_python_source(cls, task: Task, custom_operator_name: str | None = None) -> str: - operator = task().operator - if not hasattr(operator, "get_python_source"): - pytest.skip(f"Operator {operator} does not have get_python_source method") - if custom_operator_name: - cls.update_custom_operator_name(operator, custom_operator_name) - return operator.get_python_source() - - @classmethod - def parse_decorator_names(cls, source: Task | str) -> list[str]: - if not isinstance(source, str): - source = cls.parse_python_source(source) - node = ast.parse(source) - func: ast.FunctionDef = node.body[0] # type: ignore[assignment] - decorators: list[ast.Name] = func.decorator_list # type: ignore[assignment] - return [decorator.id for decorator in decorators] - - def test_branch_external_python(self): - @task.branch_virtualenv() - def f(): - return ["some_task"] - - assert not self.parse_decorator_names(f) - assert self.parse_python_source(f) == 'def f():\n return ["some_task"]\n' - - def test_branch_virtualenv(self): - @task.external_python(python="python3") - def f(): - return "hello world" - - assert not self.parse_decorator_names(f) - assert self.parse_python_source(f) == 'def f():\n return "hello world"\n' - - def test_virtualenv(self): - @task.virtualenv() - def f(): - return "hello world" - - assert not self.parse_decorator_names(f) - assert self.parse_python_source(f) == 'def f():\n return "hello world"\n' - - @pytest.mark.parametrize("decorator_name", DECORATORS) - def test_skip_if(self, decorator_name): - decorator = self.init_decorator(decorator_name) - - @task.skip_if(lambda context: True) - @decorator - def f(): - return "hello world" - - source = self.parse_python_source(f, "decorator") - assert source == 'def f():\n return "hello world"\n' - - @pytest.mark.parametrize("decorator_name", DECORATORS) - def test_run_if(self, decorator_name): - decorator = self.init_decorator(decorator_name) - - @task.run_if(lambda context: True) - @decorator - def f(): - return "hello world" - - source = self.parse_python_source(f, "decorator") - assert source == 'def f():\n return "hello world"\n' - - def test_skip_if_and_run_if(self): - @task.skip_if(lambda context: True) - @task.run_if(lambda context: True) - @task.virtualenv() - def f(): - return "hello world" - - assert self.parse_python_source(f) == 'def f():\n return "hello world"\n' - - def test_run_if_and_skip_if(self): - @task.run_if(lambda context: True) - @task.skip_if(lambda context: True) - @task.virtualenv() - def f(): - return "hello world" - - assert self.parse_python_source(f) == 'def f():\n return "hello world"\n' - - def test_skip_if_allow_decorator(self): - def decorator(func): - return func - - @task.skip_if(lambda context: True) - @task.virtualenv() - @decorator - def f(): - return "hello world" - - assert self.parse_decorator_names(f) == ["decorator"] - assert self.parse_python_source(f) == '@decorator\ndef f():\n return "hello world"\n' - - def test_run_if_allow_decorator(self): - def decorator(func): - return func - - @task.run_if(lambda context: True) - @task.virtualenv() - @decorator - def f(): - return "hello world" - - assert self.parse_decorator_names(f) == ["decorator"] - assert self.parse_python_source(f) == '@decorator\ndef f():\n return "hello world"\n' +@pytest.fixture +def decorator(request: pytest.FixtureRequest) -> TaskDecorator: + decorator_factory = getattr(task, request.param) + + kwargs = {} + if "external" in request.param: + kwargs["python"] = "python3" + return decorator_factory(**kwargs) + + +@pytest.mark.parametrize("decorator", DECORATORS_USING_SOURCE, indirect=["decorator"]) +def test_task_decorator_using_source(decorator: TaskDecorator): + @decorator + def f(): + return ["some_task"] + + assert parse_python_source(f, "decorator") == 'def f():\n return ["some_task"]\n' + + +@pytest.mark.parametrize("decorator", DECORATORS, indirect=["decorator"]) +def test_skip_if(decorator: TaskDecorator): + @task.skip_if(lambda context: True) + @decorator + def f(): + return "hello world" + + assert parse_python_source(f, "decorator") == 'def f():\n return "hello world"\n' + + +@pytest.mark.parametrize("decorator", DECORATORS, indirect=["decorator"]) +def test_run_if(decorator: TaskDecorator): + @task.run_if(lambda context: True) + @decorator + def f(): + return "hello world" + + assert parse_python_source(f, "decorator") == 'def f():\n return "hello world"\n' + + +def test_skip_if_and_run_if(): + @task.skip_if(lambda context: True) + @task.run_if(lambda context: True) + @task.virtualenv() + def f(): + return "hello world" + + assert parse_python_source(f) == 'def f():\n return "hello world"\n' + + +def test_run_if_and_skip_if(): + @task.run_if(lambda context: True) + @task.skip_if(lambda context: True) + @task.virtualenv() + def f(): + return "hello world" + + assert parse_python_source(f) == 'def f():\n return "hello world"\n' + + +def test_skip_if_allow_decorator(): + def non_task_decorator(func): + return func + + @task.skip_if(lambda context: True) + @task.virtualenv() + @non_task_decorator + def f(): + return "hello world" + + assert parse_python_source(f) == '@non_task_decorator\ndef f():\n return "hello world"\n' + + +def test_run_if_allow_decorator(): + def non_task_decorator(func): + return func + + @task.run_if(lambda context: True) + @task.virtualenv() + @non_task_decorator + def f(): + return "hello world" + + assert parse_python_source(f) == '@non_task_decorator\ndef f():\n return "hello world"\n' + + +def parse_python_source(task: Task, custom_operator_name: str | None = None) -> str: + operator = task().operator + if not hasattr(operator, "get_python_source"): + pytest.skip(f"Operator {operator} does not have get_python_source method") + if custom_operator_name: + update_custom_operator_name(operator, custom_operator_name) + return operator.get_python_source() + + +def update_custom_operator_name(operator: Any, custom_operator_name: str): + custom_operator_name = ( + custom_operator_name if custom_operator_name.startswith("@") else f"@{custom_operator_name}" + ) + operator.__dict__["custom_operator_name"] = custom_operator_name From b13c664778247028fb1ab1a7aa745a471b8740be Mon Sep 17 00:00:00 2001 From: phi Date: Wed, 4 Sep 2024 20:44:53 +0900 Subject: [PATCH 11/12] fix: rm update_custom_operator_name --- tests/utils/test_decorators.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/tests/utils/test_decorators.py b/tests/utils/test_decorators.py index 6e37bd6ce0938..82c04bcd535b9 100644 --- a/tests/utils/test_decorators.py +++ b/tests/utils/test_decorators.py @@ -16,7 +16,7 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING import pytest @@ -117,13 +117,10 @@ def parse_python_source(task: Task, custom_operator_name: str | None = None) -> operator = task().operator if not hasattr(operator, "get_python_source"): pytest.skip(f"Operator {operator} does not have get_python_source method") + if custom_operator_name: - update_custom_operator_name(operator, custom_operator_name) + custom_operator_name = ( + custom_operator_name if custom_operator_name.startswith("@") else f"@{custom_operator_name}" + ) + operator.__dict__["custom_operator_name"] = custom_operator_name return operator.get_python_source() - - -def update_custom_operator_name(operator: Any, custom_operator_name: str): - custom_operator_name = ( - custom_operator_name if custom_operator_name.startswith("@") else f"@{custom_operator_name}" - ) - operator.__dict__["custom_operator_name"] = custom_operator_name From 98d619d3c5fb01effee3c68e15d45b53c2e9959a Mon Sep 17 00:00:00 2001 From: phi Date: Fri, 6 Sep 2024 22:13:42 +0900 Subject: [PATCH 12/12] fix: rm check get_python_source --- tests/utils/test_decorators.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/utils/test_decorators.py b/tests/utils/test_decorators.py index 82c04bcd535b9..19d3ec31d0311 100644 --- a/tests/utils/test_decorators.py +++ b/tests/utils/test_decorators.py @@ -24,7 +24,12 @@ if TYPE_CHECKING: from airflow.decorators.base import Task, TaskDecorator -DECORATORS = sorted(set(x for x in dir(task) if not x.startswith("_")) - {"skip_if", "run_if"}) + +_CONDITION_DECORATORS = frozenset({"skip_if", "run_if"}) +_NO_SOURCE_DECORATORS = frozenset({"sensor"}) +DECORATORS = sorted( + set(x for x in dir(task) if not x.startswith("_")) - _CONDITION_DECORATORS - _NO_SOURCE_DECORATORS +) DECORATORS_USING_SOURCE = ("external_python", "virtualenv", "branch_virtualenv", "branch_external_python") @@ -115,9 +120,6 @@ def f(): def parse_python_source(task: Task, custom_operator_name: str | None = None) -> str: operator = task().operator - if not hasattr(operator, "get_python_source"): - pytest.skip(f"Operator {operator} does not have get_python_source method") - if custom_operator_name: custom_operator_name = ( custom_operator_name if custom_operator_name.startswith("@") else f"@{custom_operator_name}"