Skip to content

Commit ddf30db

Browse files
committed
reverting to TI due to unnecessary changes
1 parent f2db478 commit ddf30db

File tree

1 file changed

+26
-28
lines changed

1 file changed

+26
-28
lines changed

airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py

Lines changed: 26 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
from airflow.api_fastapi.core_api.security import GetUserDep
4545
from airflow.api_fastapi.core_api.services.public.common import BulkService
4646
from airflow.listeners.listener import get_listener_manager
47-
from airflow.models.taskinstance import TaskInstance
47+
from airflow.models.taskinstance import TaskInstance as TI
4848
from airflow.serialization.serialized_objects import SerializedDAG
4949
from airflow.utils.state import TaskInstanceState
5050

@@ -60,23 +60,21 @@ def _patch_ti_validate_request(
6060
session: SessionDep,
6161
map_index: int | None = -1,
6262
update_mask: list[str] | None = Query(None),
63-
) -> tuple[SerializedDAG, list[TaskInstance], dict]:
63+
) -> tuple[SerializedDAG, list[TI], dict]:
6464
dag = get_latest_version_of_dag(dag_bag, dag_id, session)
6565
if not dag.has_task(task_id):
6666
raise HTTPException(status.HTTP_404_NOT_FOUND, f"Task '{task_id}' not found in DAG '{dag_id}'")
6767

6868
query = (
69-
select(TaskInstance)
70-
.where(
71-
TaskInstance.dag_id == dag_id, TaskInstance.run_id == dag_run_id, TaskInstance.task_id == task_id
72-
)
73-
.join(TaskInstance.dag_run)
74-
.options(joinedload(TaskInstance.rendered_task_instance_fields))
69+
select(TI)
70+
.where(TI.dag_id == dag_id, TI.run_id == dag_run_id, TI.task_id == task_id)
71+
.join(TI.dag_run)
72+
.options(joinedload(TI.rendered_task_instance_fields))
7573
)
7674
if map_index is not None:
77-
query = query.where(TaskInstance.map_index == map_index)
75+
query = query.where(TI.map_index == map_index)
7876
else:
79-
query = query.order_by(TaskInstance.map_index)
77+
query = query.order_by(TI.map_index)
8078

8179
tis = session.scalars(query).all()
8280

@@ -106,7 +104,7 @@ def _patch_task_instance_state(
106104
data: dict,
107105
session: Session,
108106
commit: bool,
109-
) -> list[TaskInstance]:
107+
) -> list[TI]:
110108
map_index = getattr(task_instance_body, "map_index", None)
111109
map_indexes = None if map_index is None else [map_index]
112110

@@ -148,7 +146,7 @@ def _patch_task_instance_state(
148146

149147
def _patch_task_instance_note(
150148
task_instance_body: BulkTaskInstanceBody | PatchTaskInstanceBody,
151-
tis: list[TaskInstance],
149+
tis: list[TI],
152150
user: GetUserDep,
153151
update_mask: list[str] | None = Query(None),
154152
) -> None:
@@ -183,17 +181,17 @@ def __init__(
183181

184182
def categorize_task_instances(
185183
self, task_keys: set[tuple[str, int]]
186-
) -> tuple[dict[tuple[str, int], TaskInstance], set[tuple[str, int]], set[tuple[str, int]]]:
184+
) -> tuple[dict[tuple[str, int], TI], set[tuple[str, int]], set[tuple[str, int]]]:
187185
"""
188186
Categorize the given task_ids into matched_task_keys and not_found_task_keys based on existing task_ids.
189187
190188
:param task_keys: set of task_keys (tuple of task_id and map_index)
191189
:return: tuple of (task_instances_map, matched_task_keys, not_found_task_keys)
192190
"""
193-
query = select(TaskInstance).where(
194-
TaskInstance.dag_id == self.dag_id,
195-
TaskInstance.run_id == self.dag_run_id,
196-
TaskInstance.task_id.in_([task_id for task_id, _ in task_keys]),
191+
query = select(TI).where(
192+
TI.dag_id == self.dag_id,
193+
TI.run_id == self.dag_run_id,
194+
TI.task_id.in_([task_id for task_id, _ in task_keys]),
197195
)
198196
task_instances = self.session.scalars(query).all()
199197
task_instances_map = {
@@ -221,7 +219,7 @@ def handle_bulk_update(
221219
self, action: BulkUpdateAction[BulkTaskInstanceBody], results: BulkActionResponse
222220
) -> TaskInstanceCollectionResponse:
223221
"""Bulk Update Task Instances."""
224-
all_updated_tis: list[TaskInstance] = []
222+
all_updated_tis: list[TI] = []
225223
to_update_task_keys = {
226224
(task_instance.task_id, task_instance.map_index if task_instance.map_index is not None else -1)
227225
for task_instance in action.entities
@@ -282,7 +280,7 @@ def handle_bulk_update(
282280

283281
# Remove duplicates while preserving order
284282
seen = set()
285-
unique_tis: list[TaskInstance] = []
283+
unique_tis: list[TI] = []
286284
for ti in all_updated_tis:
287285
ti_key = (ti.dag_id, ti.run_id, ti.task_id, ti.map_index if ti.map_index is not None else -1)
288286
if ti_key not in seen:
@@ -329,11 +327,11 @@ def handle_bulk_delete(
329327
for task_id, map_index in matched_task_keys:
330328
result = (
331329
self.session.execute(
332-
select(TaskInstance).where(
333-
TaskInstance.task_id == task_id,
334-
TaskInstance.dag_id == self.dag_id,
335-
TaskInstance.run_id == self.dag_run_id,
336-
TaskInstance.map_index == map_index,
330+
select(TI).where(
331+
TI.task_id == task_id,
332+
TI.dag_id == self.dag_id,
333+
TI.run_id == self.dag_run_id,
334+
TI.map_index == map_index,
337335
)
338336
)
339337
.scalars()
@@ -348,10 +346,10 @@ def handle_bulk_delete(
348346
# Handle deletion of all map indexes for certain task_ids
349347
for task_id in delete_all_map_indexes:
350348
all_task_instances = self.session.scalars(
351-
select(TaskInstance).where(
352-
TaskInstance.task_id == task_id,
353-
TaskInstance.dag_id == self.dag_id,
354-
TaskInstance.run_id == self.dag_run_id,
349+
select(TI).where(
350+
TI.task_id == task_id,
351+
TI.dag_id == self.dag_id,
352+
TI.run_id == self.dag_run_id,
355353
)
356354
).all()
357355

0 commit comments

Comments
 (0)