4444from airflow .api_fastapi .core_api .security import GetUserDep
4545from airflow .api_fastapi .core_api .services .public .common import BulkService
4646from airflow .listeners .listener import get_listener_manager
47- from airflow .models .taskinstance import TaskInstance
47+ from airflow .models .taskinstance import TaskInstance as TI
4848from airflow .serialization .serialized_objects import SerializedDAG
4949from airflow .utils .state import TaskInstanceState
5050
@@ -60,23 +60,21 @@ def _patch_ti_validate_request(
6060 session : SessionDep ,
6161 map_index : int | None = - 1 ,
6262 update_mask : list [str ] | None = Query (None ),
63- ) -> tuple [SerializedDAG , list [TaskInstance ], dict ]:
63+ ) -> tuple [SerializedDAG , list [TI ], dict ]:
6464 dag = get_latest_version_of_dag (dag_bag , dag_id , session )
6565 if not dag .has_task (task_id ):
6666 raise HTTPException (status .HTTP_404_NOT_FOUND , f"Task '{ task_id } ' not found in DAG '{ dag_id } '" )
6767
6868 query = (
69- select (TaskInstance )
70- .where (
71- TaskInstance .dag_id == dag_id , TaskInstance .run_id == dag_run_id , TaskInstance .task_id == task_id
72- )
73- .join (TaskInstance .dag_run )
74- .options (joinedload (TaskInstance .rendered_task_instance_fields ))
69+ select (TI )
70+ .where (TI .dag_id == dag_id , TI .run_id == dag_run_id , TI .task_id == task_id )
71+ .join (TI .dag_run )
72+ .options (joinedload (TI .rendered_task_instance_fields ))
7573 )
7674 if map_index is not None :
77- query = query .where (TaskInstance .map_index == map_index )
75+ query = query .where (TI .map_index == map_index )
7876 else :
79- query = query .order_by (TaskInstance .map_index )
77+ query = query .order_by (TI .map_index )
8078
8179 tis = session .scalars (query ).all ()
8280
@@ -106,7 +104,7 @@ def _patch_task_instance_state(
106104 data : dict ,
107105 session : Session ,
108106 commit : bool ,
109- ) -> list [TaskInstance ]:
107+ ) -> list [TI ]:
110108 map_index = getattr (task_instance_body , "map_index" , None )
111109 map_indexes = None if map_index is None else [map_index ]
112110
@@ -148,7 +146,7 @@ def _patch_task_instance_state(
148146
149147def _patch_task_instance_note (
150148 task_instance_body : BulkTaskInstanceBody | PatchTaskInstanceBody ,
151- tis : list [TaskInstance ],
149+ tis : list [TI ],
152150 user : GetUserDep ,
153151 update_mask : list [str ] | None = Query (None ),
154152) -> None :
@@ -183,17 +181,17 @@ def __init__(
183181
184182 def categorize_task_instances (
185183 self , task_keys : set [tuple [str , int ]]
186- ) -> tuple [dict [tuple [str , int ], TaskInstance ], set [tuple [str , int ]], set [tuple [str , int ]]]:
184+ ) -> tuple [dict [tuple [str , int ], TI ], set [tuple [str , int ]], set [tuple [str , int ]]]:
187185 """
188186 Categorize the given task_ids into matched_task_keys and not_found_task_keys based on existing task_ids.
189187
190188 :param task_keys: set of task_keys (tuple of task_id and map_index)
191189 :return: tuple of (task_instances_map, matched_task_keys, not_found_task_keys)
192190 """
193- query = select (TaskInstance ).where (
194- TaskInstance .dag_id == self .dag_id ,
195- TaskInstance .run_id == self .dag_run_id ,
196- TaskInstance .task_id .in_ ([task_id for task_id , _ in task_keys ]),
191+ query = select (TI ).where (
192+ TI .dag_id == self .dag_id ,
193+ TI .run_id == self .dag_run_id ,
194+ TI .task_id .in_ ([task_id for task_id , _ in task_keys ]),
197195 )
198196 task_instances = self .session .scalars (query ).all ()
199197 task_instances_map = {
@@ -221,7 +219,7 @@ def handle_bulk_update(
221219 self , action : BulkUpdateAction [BulkTaskInstanceBody ], results : BulkActionResponse
222220 ) -> TaskInstanceCollectionResponse :
223221 """Bulk Update Task Instances."""
224- all_updated_tis : list [TaskInstance ] = []
222+ all_updated_tis : list [TI ] = []
225223 to_update_task_keys = {
226224 (task_instance .task_id , task_instance .map_index if task_instance .map_index is not None else - 1 )
227225 for task_instance in action .entities
@@ -282,7 +280,7 @@ def handle_bulk_update(
282280
283281 # Remove duplicates while preserving order
284282 seen = set ()
285- unique_tis : list [TaskInstance ] = []
283+ unique_tis : list [TI ] = []
286284 for ti in all_updated_tis :
287285 ti_key = (ti .dag_id , ti .run_id , ti .task_id , ti .map_index if ti .map_index is not None else - 1 )
288286 if ti_key not in seen :
@@ -329,11 +327,11 @@ def handle_bulk_delete(
329327 for task_id , map_index in matched_task_keys :
330328 result = (
331329 self .session .execute (
332- select (TaskInstance ).where (
333- TaskInstance .task_id == task_id ,
334- TaskInstance .dag_id == self .dag_id ,
335- TaskInstance .run_id == self .dag_run_id ,
336- TaskInstance .map_index == map_index ,
330+ select (TI ).where (
331+ TI .task_id == task_id ,
332+ TI .dag_id == self .dag_id ,
333+ TI .run_id == self .dag_run_id ,
334+ TI .map_index == map_index ,
337335 )
338336 )
339337 .scalars ()
@@ -348,10 +346,10 @@ def handle_bulk_delete(
348346 # Handle deletion of all map indexes for certain task_ids
349347 for task_id in delete_all_map_indexes :
350348 all_task_instances = self .session .scalars (
351- select (TaskInstance ).where (
352- TaskInstance .task_id == task_id ,
353- TaskInstance .dag_id == self .dag_id ,
354- TaskInstance .run_id == self .dag_run_id ,
349+ select (TI ).where (
350+ TI .task_id == task_id ,
351+ TI .dag_id == self .dag_id ,
352+ TI .run_id == self .dag_run_id ,
355353 )
356354 ).all ()
357355
0 commit comments