diff --git a/airflow/api_connexion/endpoints/dag_run_endpoint.py b/airflow/api_connexion/endpoints/dag_run_endpoint.py index e0d83575611b0..8ebb2b44e2bb3 100644 --- a/airflow/api_connexion/endpoints/dag_run_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_run_endpoint.py @@ -373,6 +373,7 @@ def post_dag_run(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse: raise AlreadyExists(detail=f"DAGRun with DAG ID: '{dag_id}' and DAGRun ID: '{run_id}' already exists") +@mark_fastapi_migration_done @security.requires_access_dag("PUT", DagAccessEntity.RUN) @provide_session @action_logging diff --git a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml index 3d1b1611ad8c9..ae5fc9e117738 100644 --- a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml +++ b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml @@ -1156,6 +1156,78 @@ paths: application/json: schema: $ref: '#/components/schemas/HTTPValidationError' + patch: + tags: + - DagRun + summary: Patch Dag Run State + description: Modify a DAG Run. + operationId: patch_dag_run_state + parameters: + - name: dag_id + in: path + required: true + schema: + type: string + title: Dag Id + - name: dag_run_id + in: path + required: true + schema: + type: string + title: Dag Run Id + - name: update_mask + in: query + required: false + schema: + anyOf: + - type: array + items: + type: string + - type: 'null' + title: Update Mask + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/DAGRunPatchBody' + responses: + '200': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/DAGRunResponse' + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Bad Request + '401': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Unauthorized + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Forbidden + '404': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Not Found + '422': + description: Validation Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' /public/monitor/health: get: tags: @@ -2079,6 +2151,23 @@ components: - file_token title: DAGResponse description: DAG serializer for responses. + DAGRunPatchBody: + properties: + state: + $ref: '#/components/schemas/DAGRunPatchStates' + type: object + required: + - state + title: DAGRunPatchBody + description: DAG Run Serializer for PATCH requests. + DAGRunPatchStates: + type: string + enum: + - queued + - success + - failed + title: DAGRunPatchStates + description: Enum for DAG Run states when updating a DAG Run. DAGRunResponse: properties: run_id: diff --git a/airflow/api_fastapi/core_api/routes/public/dag_run.py b/airflow/api_fastapi/core_api/routes/public/dag_run.py index 035d1b7fd7dc2..02780d6088e94 100644 --- a/airflow/api_fastapi/core_api/routes/public/dag_run.py +++ b/airflow/api_fastapi/core_api/routes/public/dag_run.py @@ -17,16 +17,25 @@ from __future__ import annotations -from fastapi import Depends, HTTPException +from fastapi import Depends, HTTPException, Query, Request from sqlalchemy import select from sqlalchemy.orm import Session from typing_extensions import Annotated +from airflow.api.common.mark_tasks import ( + set_dag_run_state_to_failed, + set_dag_run_state_to_queued, + set_dag_run_state_to_success, +) from airflow.api_fastapi.common.db.common import get_session from airflow.api_fastapi.common.router import AirflowRouter from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc -from airflow.api_fastapi.core_api.serializers.dag_run import DAGRunResponse -from airflow.models import DagRun +from airflow.api_fastapi.core_api.serializers.dag_run import ( + DAGRunPatchBody, + DAGRunPatchStates, + DAGRunResponse, +) +from airflow.models import DAG, DagRun dag_run_router = AirflowRouter(tags=["DagRun"], prefix="/dags/{dag_id}/dagRuns") @@ -57,3 +66,45 @@ async def delete_dag_run(dag_id: str, dag_run_id: str, session: Annotated[Sessio ) session.delete(dag_run) + + +@dag_run_router.patch("/{dag_run_id}", responses=create_openapi_http_exception_doc([400, 401, 403, 404])) +async def patch_dag_run_state( + dag_id: str, + dag_run_id: str, + patch_body: DAGRunPatchBody, + session: Annotated[Session, Depends(get_session)], + request: Request, + update_mask: list[str] | None = Query(None), +) -> DAGRunResponse: + """Modify a DAG Run.""" + dag_run = session.scalar(select(DagRun).filter_by(dag_id=dag_id, run_id=dag_run_id)) + if dag_run is None: + raise HTTPException( + 404, f"The DagRun with dag_id: `{dag_id}` and run_id: `{dag_run_id}` was not found" + ) + + dag: DAG = request.app.state.dag_bag.get_dag(dag_id) + + if not dag: + raise HTTPException(404, f"Dag with id {dag_id} was not found") + + if update_mask: + if update_mask != ["state"]: + raise HTTPException(400, "Only `state` field can be updated through the REST API") + else: + update_mask = ["state"] + + for attr_name in update_mask: + if attr_name == "state": + state = getattr(patch_body, attr_name) + if state == DAGRunPatchStates.SUCCESS: + set_dag_run_state_to_success(dag=dag, run_id=dag_run.run_id, commit=True) + elif state == DAGRunPatchStates.QUEUED: + set_dag_run_state_to_queued(dag=dag, run_id=dag_run.run_id, commit=True) + else: + set_dag_run_state_to_failed(dag=dag, run_id=dag_run.run_id, commit=True) + + dag_run = session.get(DagRun, dag_run.id) + + return DAGRunResponse.model_validate(dag_run, from_attributes=True) diff --git a/airflow/api_fastapi/core_api/serializers/dag_run.py b/airflow/api_fastapi/core_api/serializers/dag_run.py index 4622fac645c07..15576905611c3 100644 --- a/airflow/api_fastapi/core_api/serializers/dag_run.py +++ b/airflow/api_fastapi/core_api/serializers/dag_run.py @@ -18,6 +18,7 @@ from __future__ import annotations from datetime import datetime +from enum import Enum from pydantic import BaseModel, Field @@ -25,6 +26,20 @@ from airflow.utils.types import DagRunTriggeredByType, DagRunType +class DAGRunPatchStates(str, Enum): + """Enum for DAG Run states when updating a DAG Run.""" + + QUEUED = DagRunState.QUEUED + SUCCESS = DagRunState.SUCCESS + FAILED = DagRunState.FAILED + + +class DAGRunPatchBody(BaseModel): + """DAG Run Serializer for PATCH requests.""" + + state: DAGRunPatchStates + + class DAGRunResponse(BaseModel): """DAG Run serializer for responses.""" diff --git a/airflow/ui/openapi-gen/queries/common.ts b/airflow/ui/openapi-gen/queries/common.ts index 959b476718d22..b8c83b1525b22 100644 --- a/airflow/ui/openapi-gen/queries/common.ts +++ b/airflow/ui/openapi-gen/queries/common.ts @@ -441,6 +441,9 @@ export type DagServicePatchDagMutationResult = Awaited< export type VariableServicePatchVariableMutationResult = Awaited< ReturnType >; +export type DagRunServicePatchDagRunStateMutationResult = Awaited< + ReturnType +>; export type PoolServicePatchPoolMutationResult = Awaited< ReturnType >; diff --git a/airflow/ui/openapi-gen/queries/queries.ts b/airflow/ui/openapi-gen/queries/queries.ts index a3aed2e793718..a0d6a6585304f 100644 --- a/airflow/ui/openapi-gen/queries/queries.ts +++ b/airflow/ui/openapi-gen/queries/queries.ts @@ -23,6 +23,7 @@ import { } from "../requests/services.gen"; import { DAGPatchBody, + DAGRunPatchBody, DagRunState, PoolPatchBody, PoolPostBody, @@ -948,6 +949,57 @@ export const useVariableServicePatchVariable = < }) as unknown as Promise, ...options, }); +/** + * Patch Dag Run State + * Modify a DAG Run. + * @param data The data for the request. + * @param data.dagId + * @param data.dagRunId + * @param data.requestBody + * @param data.updateMask + * @returns DAGRunResponse Successful Response + * @throws ApiError + */ +export const useDagRunServicePatchDagRunState = < + TData = Common.DagRunServicePatchDagRunStateMutationResult, + TError = unknown, + TContext = unknown, +>( + options?: Omit< + UseMutationOptions< + TData, + TError, + { + dagId: string; + dagRunId: string; + requestBody: DAGRunPatchBody; + updateMask?: string[]; + }, + TContext + >, + "mutationFn" + >, +) => + useMutation< + TData, + TError, + { + dagId: string; + dagRunId: string; + requestBody: DAGRunPatchBody; + updateMask?: string[]; + }, + TContext + >({ + mutationFn: ({ dagId, dagRunId, requestBody, updateMask }) => + DagRunService.patchDagRunState({ + dagId, + dagRunId, + requestBody, + updateMask, + }) as unknown as Promise, + ...options, + }); /** * Patch Pool * Update a Pool. diff --git a/airflow/ui/openapi-gen/requests/schemas.gen.ts b/airflow/ui/openapi-gen/requests/schemas.gen.ts index b1a5b267e11ef..3982407c5f181 100644 --- a/airflow/ui/openapi-gen/requests/schemas.gen.ts +++ b/airflow/ui/openapi-gen/requests/schemas.gen.ts @@ -887,6 +887,25 @@ export const $DAGResponse = { description: "DAG serializer for responses.", } as const; +export const $DAGRunPatchBody = { + properties: { + state: { + $ref: "#/components/schemas/DAGRunPatchStates", + }, + }, + type: "object", + required: ["state"], + title: "DAGRunPatchBody", + description: "DAG Run Serializer for PATCH requests.", +} as const; + +export const $DAGRunPatchStates = { + type: "string", + enum: ["queued", "success", "failed"], + title: "DAGRunPatchStates", + description: "Enum for DAG Run states when updating a DAG Run.", +} as const; + export const $DAGRunResponse = { properties: { run_id: { diff --git a/airflow/ui/openapi-gen/requests/services.gen.ts b/airflow/ui/openapi-gen/requests/services.gen.ts index 4db1e052a202a..56207631d154d 100644 --- a/airflow/ui/openapi-gen/requests/services.gen.ts +++ b/airflow/ui/openapi-gen/requests/services.gen.ts @@ -43,6 +43,8 @@ import type { GetDagRunResponse, DeleteDagRunData, DeleteDagRunResponse, + PatchDagRunStateData, + PatchDagRunStateResponse, GetHealthResponse, DeletePoolData, DeletePoolResponse, @@ -672,6 +674,42 @@ export class DagRunService { }, }); } + + /** + * Patch Dag Run State + * Modify a DAG Run. + * @param data The data for the request. + * @param data.dagId + * @param data.dagRunId + * @param data.requestBody + * @param data.updateMask + * @returns DAGRunResponse Successful Response + * @throws ApiError + */ + public static patchDagRunState( + data: PatchDagRunStateData, + ): CancelablePromise { + return __request(OpenAPI, { + method: "PATCH", + url: "/public/dags/{dag_id}/dagRuns/{dag_run_id}", + path: { + dag_id: data.dagId, + dag_run_id: data.dagRunId, + }, + query: { + update_mask: data.updateMask, + }, + body: data.requestBody, + mediaType: "application/json", + errors: { + 400: "Bad Request", + 401: "Unauthorized", + 403: "Forbidden", + 404: "Not Found", + 422: "Validation Error", + }, + }); + } } export class MonitorService { diff --git a/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow/ui/openapi-gen/requests/types.gen.ts index 288916b3928ec..a3b1d8e6bef68 100644 --- a/airflow/ui/openapi-gen/requests/types.gen.ts +++ b/airflow/ui/openapi-gen/requests/types.gen.ts @@ -154,6 +154,18 @@ export type DAGResponse = { readonly file_token: string; }; +/** + * DAG Run Serializer for PATCH requests. + */ +export type DAGRunPatchBody = { + state: DAGRunPatchStates; +}; + +/** + * Enum for DAG Run states when updating a DAG Run. + */ +export type DAGRunPatchStates = "queued" | "success" | "failed"; + /** * DAG Run serializer for responses. */ @@ -680,6 +692,15 @@ export type DeleteDagRunData = { export type DeleteDagRunResponse = void; +export type PatchDagRunStateData = { + dagId: string; + dagRunId: string; + requestBody: DAGRunPatchBody; + updateMask?: Array | null; +}; + +export type PatchDagRunStateResponse = DAGRunResponse; + export type GetHealthResponse = HealthInfoSchema; export type DeletePoolData = { @@ -1236,6 +1257,35 @@ export type $OpenApiTs = { 422: HTTPValidationError; }; }; + patch: { + req: PatchDagRunStateData; + res: { + /** + * Successful Response + */ + 200: DAGRunResponse; + /** + * Bad Request + */ + 400: HTTPExceptionResponse; + /** + * Unauthorized + */ + 401: HTTPExceptionResponse; + /** + * Forbidden + */ + 403: HTTPExceptionResponse; + /** + * Not Found + */ + 404: HTTPExceptionResponse; + /** + * Validation Error + */ + 422: HTTPValidationError; + }; + }; }; "/public/monitor/health": { get: { diff --git a/tests/api_fastapi/core_api/routes/public/test_dag_run.py b/tests/api_fastapi/core_api/routes/public/test_dag_run.py index 554bc73ebab4f..dfd48af2fa530 100644 --- a/tests/api_fastapi/core_api/routes/public/test_dag_run.py +++ b/tests/api_fastapi/core_api/routes/public/test_dag_run.py @@ -138,6 +138,56 @@ def test_get_dag_run_not_found(self, test_client): assert body["detail"] == "The DagRun with dag_id: `test_dag1` and run_id: `invalid` was not found" +class TestPatchDagRun: + @pytest.mark.parametrize( + "dag_id, run_id, state, response_state", + [ + (DAG1_ID, DAG1_RUN1_ID, DagRunState.FAILED, DagRunState.FAILED), + (DAG1_ID, DAG1_RUN2_ID, DagRunState.SUCCESS, DagRunState.SUCCESS), + (DAG2_ID, DAG2_RUN1_ID, DagRunState.QUEUED, DagRunState.QUEUED), + ], + ) + def test_patch_dag_run(self, test_client, dag_id, run_id, state, response_state): + response = test_client.patch(f"/public/dags/{dag_id}/dagRuns/{run_id}", json={"state": state}) + assert response.status_code == 200 + body = response.json() + assert body["dag_id"] == dag_id + assert body["run_id"] == run_id + assert body["state"] == response_state + + @pytest.mark.parametrize( + "query_params,patch_body, expected_status_code", + [ + ({"update_mask": ["state"]}, {"state": DagRunState.SUCCESS}, 200), + ({}, {"state": DagRunState.SUCCESS}, 200), + ({"update_mask": ["random"]}, {"state": DagRunState.SUCCESS}, 400), + ], + ) + def test_patch_dag_run_with_update_mask( + self, test_client, query_params, patch_body, expected_status_code + ): + response = test_client.patch( + f"/public/dags/{DAG1_ID}/dagRuns/{DAG1_RUN1_ID}", params=query_params, json=patch_body + ) + assert response.status_code == expected_status_code + + def test_patch_dag_run_not_found(self, test_client): + response = test_client.patch( + f"/public/dags/{DAG1_ID}/dagRuns/invalid", json={"state": DagRunState.SUCCESS} + ) + assert response.status_code == 404 + body = response.json() + assert body["detail"] == "The DagRun with dag_id: `test_dag1` and run_id: `invalid` was not found" + + def test_patch_dag_run_bad_request(self, test_client): + response = test_client.patch( + f"/public/dags/{DAG1_ID}/dagRuns/{DAG1_RUN1_ID}", json={"state": "running"} + ) + assert response.status_code == 422 + body = response.json() + assert body["detail"][0]["msg"] == "Input should be 'queued', 'success' or 'failed'" + + class TestDeleteDagRun: def test_delete_dag_run(self, test_client): response = test_client.delete(f"/public/dags/{DAG1_ID}/dagRuns/{DAG1_RUN1_ID}")