diff --git a/airflow-core/src/airflow/api_fastapi/common/db/task_instance.py b/airflow-core/src/airflow/api_fastapi/common/db/task_instance.py deleted file mode 100644 index b6748bdd61a4c..0000000000000 --- a/airflow-core/src/airflow/api_fastapi/common/db/task_instance.py +++ /dev/null @@ -1,57 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-from __future__ import annotations - -from pydantic import PositiveInt -from sqlalchemy.orm import joinedload -from sqlalchemy.sql import select - -from airflow.api_fastapi.common.db.common import SessionDep -from airflow.models import TaskInstance, Trigger -from airflow.models.taskinstancehistory import TaskInstanceHistory - - -def get_task_instance_or_history_for_try_number( - dag_id: str, - dag_run_id: str, - task_id: str, - try_number: PositiveInt, - session: SessionDep, - map_index: int, -) -> TaskInstance | TaskInstanceHistory: - query = ( - select(TaskInstance) - .where( - TaskInstance.task_id == task_id, - TaskInstance.dag_id == dag_id, - TaskInstance.run_id == dag_run_id, - TaskInstance.map_index == map_index, - ) - .join(TaskInstance.dag_run) - .options(joinedload(TaskInstance.trigger).joinedload(Trigger.triggerer_job)) - ) - ti = session.scalar(query) - if ti is None or ti.try_number != try_number: - query = select(TaskInstanceHistory).where( - TaskInstanceHistory.task_id == task_id, - TaskInstanceHistory.dag_id == dag_id, - TaskInstanceHistory.run_id == dag_run_id, - TaskInstanceHistory.map_index == map_index, - TaskInstanceHistory.try_number == try_number, - ) - ti = session.scalar(query) - return ti diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/log.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/log.py index d7b7914e130bc..282e7b8f945a3 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/log.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/log.py @@ -28,7 +28,6 @@ from airflow.api_fastapi.common.dagbag import DagBagDep from airflow.api_fastapi.common.db.common import SessionDep -from airflow.api_fastapi.common.db.task_instance import get_task_instance_or_history_for_try_number from airflow.api_fastapi.common.headers import HeaderAcceptJsonOrNdjson from airflow.api_fastapi.common.router import AirflowRouter from airflow.api_fastapi.common.types import Mimetype @@ 
-36,7 +35,8 @@ from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc from airflow.api_fastapi.core_api.security import DagAccessEntity, requires_access_dag from airflow.exceptions import TaskNotFound -from airflow.models import TaskInstance +from airflow.models import TaskInstance, Trigger +from airflow.models.taskinstancehistory import TaskInstanceHistory from airflow.utils.log.log_reader import TaskLogReader task_instances_log_router = AirflowRouter( @@ -105,14 +105,28 @@ def get_log( if not task_log_reader.supports_read: raise HTTPException(status.HTTP_400_BAD_REQUEST, "Task log handler does not support read logs.") - ti = get_task_instance_or_history_for_try_number( - dag_id=dag_id, - dag_run_id=dag_run_id, - task_id=task_id, - try_number=try_number, - session=session, - map_index=map_index, + query = ( + select(TaskInstance) + .where( + TaskInstance.task_id == task_id, + TaskInstance.dag_id == dag_id, + TaskInstance.run_id == dag_run_id, + TaskInstance.map_index == map_index, + ) + .join(TaskInstance.dag_run) + .options(joinedload(TaskInstance.trigger).joinedload(Trigger.triggerer_job)) + .options(joinedload(TaskInstance.dag_model)) ) + ti = session.scalar(query) + if ti is None or ti.try_number != try_number: + query = select(TaskInstanceHistory).where( + TaskInstanceHistory.task_id == task_id, + TaskInstanceHistory.dag_id == dag_id, + TaskInstanceHistory.run_id == dag_run_id, + TaskInstanceHistory.map_index == map_index, + TaskInstanceHistory.try_number == try_number, + ) + ti = session.scalar(query) if ti is None: metadata["end_of_log"] = True diff --git a/airflow-core/src/airflow/utils/log/file_task_handler.py b/airflow-core/src/airflow/utils/log/file_task_handler.py index 8d6b47953c424..c39792baa81c9 100644 --- a/airflow-core/src/airflow/utils/log/file_task_handler.py +++ b/airflow-core/src/airflow/utils/log/file_task_handler.py @@ -586,9 +586,7 @@ def _read_from_logs_server(self, ti, worker_log_rel_path) -> tuple[LogSourceInfo sources = [] logs
= [] try: - log_type = ( - LogType.TRIGGER if hasattr(ti, "triggerer_job") and ti.triggerer_job else LogType.WORKER - ) + log_type = LogType.TRIGGER if ti.triggerer_job else LogType.WORKER url, rel_path = self._get_log_retrieval_url(ti, worker_log_rel_path, log_type=log_type) response = _fetch_logs_from_service(url, rel_path) if response.status_code == 403: diff --git a/airflow-core/tests/unit/api_fastapi/common/db/__init__.py b/airflow-core/tests/unit/api_fastapi/common/db/__init__.py deleted file mode 100644 index 13a83393a9124..0000000000000 --- a/airflow-core/tests/unit/api_fastapi/common/db/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. diff --git a/airflow-core/tests/unit/api_fastapi/common/db/test_task_instance.py b/airflow-core/tests/unit/api_fastapi/common/db/test_task_instance.py deleted file mode 100644 index e5ed829dc8f4b..0000000000000 --- a/airflow-core/tests/unit/api_fastapi/common/db/test_task_instance.py +++ /dev/null @@ -1,79 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from __future__ import annotations - -import pytest - -from airflow.api_fastapi.common.db.task_instance import get_task_instance_or_history_for_try_number -from airflow.models.taskinstance import TaskInstance -from airflow.models.taskinstancehistory import TaskInstanceHistory -from airflow.providers.standard.operators.empty import EmptyOperator -from airflow.utils import timezone -from airflow.utils.types import DagRunType - -from tests_common.test_utils.db import clear_db_runs - -pytestmark = pytest.mark.db_test - - -class TestDBTaskInstance: - DAG_ID = "dag_for_testing_db_task_instance" - RUN_ID = "dag_run_id_for_testing_db_task_instance" - TASK_ID = "task_for_testing_db_task_instance" - TRY_NUMBER = 1 - - default_time = "2020-06-10T20:00:00+00:00" - - @pytest.fixture(autouse=True) - def setup_attrs(self, dag_maker, session) -> None: - with dag_maker(self.DAG_ID, start_date=timezone.parse(self.default_time), session=session) as dag: - EmptyOperator(task_id=self.TASK_ID) - - dr = dag_maker.create_dagrun( - run_id=self.RUN_ID, - run_type=DagRunType.SCHEDULED, - logical_date=timezone.parse(self.default_time), - start_date=timezone.parse(self.default_time), - ) - - for ti in dr.task_instances: - ti.try_number = 1 - ti.hostname = "localhost" - session.merge(ti) - dag.clear() - for ti in dr.task_instances: - ti.try_number = 2 - ti.hostname = "localhost" - session.merge(ti) - session.commit() - - 
def teardown_method(self): - clear_db_runs() - - @pytest.mark.parametrize("try_number", [1, 2]) - def test_get_task_instance_or_history_for_try_number(self, try_number, session): - ti = get_task_instance_or_history_for_try_number( - self.DAG_ID, - self.RUN_ID, - self.TASK_ID, - try_number, - session=session, - map_index=-1, - ) - - assert isinstance(ti, TaskInstanceHistory) if try_number == 1 else TaskInstance