From c4dca221ae5fc6f9a0828443858a65543bd56fbc Mon Sep 17 00:00:00 2001
From: Jens Scheffler
Date: Sun, 24 Nov 2024 23:37:15 +0100
Subject: [PATCH 1/3] Migrate Edge calls for Worker to FastAPI 2 - Logs route

---
 .../src/airflow/providers/edge/CHANGELOG.rst  |   8 +
 .../src/airflow/providers/edge/__init__.py    |   2 +-
 .../airflow/providers/edge/cli/api_client.py  |  31 ++-
 .../providers/edge/cli/edge_command.py        |  16 +-
 .../providers/edge/models/edge_logs.py        |   2 +-
 .../edge/openapi/edge_worker_api_v1.yaml      | 197 ++++++++++++++++++
 .../src/airflow/providers/edge/provider.yaml  |   2 +-
 .../airflow/providers/edge/worker_api/app.py  |   2 +
 .../providers/edge/worker_api/datamodels.py   |  22 ++
 .../edge/worker_api/routes/_v2_routes.py      |  47 ++++-
 .../providers/edge/worker_api/routes/logs.py  | 133 ++++++++++++
 providers/tests/edge/cli/test_edge_command.py |  41 ++--
 providers/tests/edge/models/test_edge_logs.py |  49 -----
 .../tests/edge/worker_api/routes/test_logs.py |  75 +++++++
 14 files changed, 542 insertions(+), 85 deletions(-)
 create mode 100644 providers/src/airflow/providers/edge/worker_api/routes/logs.py
 delete mode 100644 providers/tests/edge/models/test_edge_logs.py
 create mode 100644 providers/tests/edge/worker_api/routes/test_logs.py

diff --git a/providers/src/airflow/providers/edge/CHANGELOG.rst b/providers/src/airflow/providers/edge/CHANGELOG.rst
index 8309f111f6a3f..93900dfeb5086 100644
--- a/providers/src/airflow/providers/edge/CHANGELOG.rst
+++ b/providers/src/airflow/providers/edge/CHANGELOG.rst
@@ -27,6 +27,14 @@ Changelog
 ---------
 
+0.8.1pre0
+.........
+
+Misc
+~~~~
+
+* ``Migrate worker log calls to FastAPI.``
+
 0.8.0pre0
 .........
 
diff --git a/providers/src/airflow/providers/edge/__init__.py b/providers/src/airflow/providers/edge/__init__.py
index fd23acee829cc..8c53d0bed1be2 100644
--- a/providers/src/airflow/providers/edge/__init__.py
+++ b/providers/src/airflow/providers/edge/__init__.py
@@ -29,7 +29,7 @@
 
 __all__ = ["__version__"]
 
-__version__ = "0.8.0pre0"
+__version__ = "0.8.1pre0"
 
 if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
     "2.10.0"
diff --git a/providers/src/airflow/providers/edge/cli/api_client.py b/providers/src/airflow/providers/edge/cli/api_client.py
index 9174191fd8c35..9b5781e359d9e 100644
--- a/providers/src/airflow/providers/edge/cli/api_client.py
+++ b/providers/src/airflow/providers/edge/cli/api_client.py
@@ -32,9 +32,10 @@
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.providers.edge.worker_api.auth import jwt_signer
-from airflow.providers.edge.worker_api.datamodels import WorkerStateBody
+from airflow.providers.edge.worker_api.datamodels import PushLogsBody, WorkerStateBody
 
 if TYPE_CHECKING:
+    from airflow.models.taskinstancekey import TaskInstanceKey
     from airflow.providers.edge.models.edge_worker import EdgeWorkerState
 
 logger = logging.getLogger(__name__)
@@ -64,7 +65,7 @@ def _is_retryable_exception(exception: BaseException) -> bool:
     retry=tenacity.retry_if_exception(_is_retryable_exception),
     before_sleep=tenacity.before_log(logger, logging.WARNING),
 )
-def _make_generic_request(method: str, rest_path: str, data: str) -> Any:
+def _make_generic_request(method: str, rest_path: str, data: str | None = None) -> Any:
     signer = jwt_signer()
     api_url = conf.get("edge", "api_url")
     path = urlparse(api_url).path.replace("/rpcapi", "")
@@ -104,11 +105,33 @@ def worker_set_state(
     hostname: str, state: EdgeWorkerState, jobs_active: int, queues: list[str] | None, sysinfo: dict
 ) -> list[str] | None:
     """Register worker with the Edge API."""
-    result = _make_generic_request(
+    return _make_generic_request(
         "PATCH",
         f"worker/{quote(hostname)}",
         WorkerStateBody(state=state, jobs_active=jobs_active, queues=queues, sysinfo=sysinfo).model_dump_json(
             exclude_unset=True
         ),
     )
-    return result
+
+
+def logs_logfile_path(task: TaskInstanceKey) -> Path:
+    """Determine the path and filename to expect from task execution."""
+    result = _make_generic_request(
+        "GET",
+        f"logs/logfile_path/{task.dag_id}/{task.task_id}/{task.run_id}/{task.try_number}/{task.map_index}",
+    )
+    base_log_folder = conf.get("logging", "base_log_folder", fallback="NOT AVAILABLE")
+    return Path(base_log_folder, result)
+
+
+def logs_push(
+    task: TaskInstanceKey,
+    log_chunk_time: datetime,
+    log_chunk_data: str,
+) -> None:
+    """Push an incremental log chunk from Edge Worker to central site."""
+    _make_generic_request(
+        "POST",
+        f"logs/push/{task.dag_id}/{task.task_id}/{task.run_id}/{task.try_number}/{task.map_index}",
+        PushLogsBody(log_chunk_time=log_chunk_time, log_chunk_data=log_chunk_data).model_dump_json(),
+    )
diff --git a/providers/src/airflow/providers/edge/cli/edge_command.py b/providers/src/airflow/providers/edge/cli/edge_command.py
index 9d172bffdd5b9..f175d6e77b3c4 100644
--- a/providers/src/airflow/providers/edge/cli/edge_command.py
+++ b/providers/src/airflow/providers/edge/cli/edge_command.py
@@ -36,11 +36,15 @@
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.providers.edge import __version__ as edge_provider_version
-from airflow.providers.edge.cli.api_client import worker_register, worker_set_state
+from airflow.providers.edge.cli.api_client import (
+    logs_logfile_path,
+    logs_push,
+    worker_register,
+    worker_set_state,
+)
 from airflow.providers.edge.models.edge_job import EdgeJob
-from airflow.providers.edge.models.edge_logs import EdgeLogs
 from airflow.providers.edge.models.edge_worker import EdgeWorkerState, EdgeWorkerVersionException
-from airflow.utils import cli as cli_utils
+from airflow.utils import cli as cli_utils, timezone
 from airflow.utils.platform import IS_WINDOWS
 from airflow.utils.providers_configuration_loader import providers_configuration_loaded
 from airflow.utils.state import TaskInstanceState
@@ -246,7 +250,7 @@ def fetch_job(self) -> bool:
             env["AIRFLOW__CORE__INTERNAL_API_URL"] = conf.get("edge", "api_url")
             env["_AIRFLOW__SKIP_DATABASE_EXECUTOR_COMPATIBILITY_CHECK"] = "1"
             process = Popen(edge_job.command, close_fds=True, env=env, start_new_session=True)
-            logfile = EdgeLogs.logfile_path(edge_job.key)
+            logfile = logs_logfile_path(edge_job.key)
             self.jobs.append(_Job(edge_job, process, logfile, 0))
             EdgeJob.set_state(edge_job.key, TaskInstanceState.RUNNING)
             return True
@@ -285,9 +289,9 @@ def check_running_jobs(self) -> None:
                     if not chunk_data:
                         break
 
-                    EdgeLogs.push_logs(
+                    logs_push(
                         task=job.edge_job.key,
-                        log_chunk_time=datetime.now(),
+                        log_chunk_time=timezone.utcnow(),
                         log_chunk_data=chunk_data,
                     )
 
diff --git a/providers/src/airflow/providers/edge/models/edge_logs.py b/providers/src/airflow/providers/edge/models/edge_logs.py
index 29625f5be757b..65146cf7edc3b 100644
--- a/providers/src/airflow/providers/edge/models/edge_logs.py
+++ b/providers/src/airflow/providers/edge/models/edge_logs.py
@@ -87,7 +87,7 @@ def __init__(
 
 
 class EdgeLogs(BaseModel, LoggingMixin):
-    """Accessor for Edge Worker instances as logical model."""
+    """Deprecated Internal API for Edge Worker instances as logical model."""
model.""" dag_id: str task_id: str diff --git a/providers/src/airflow/providers/edge/openapi/edge_worker_api_v1.yaml b/providers/src/airflow/providers/edge/openapi/edge_worker_api_v1.yaml index 8be23c0d07cc3..7915bdb5b4aa4 100644 --- a/providers/src/airflow/providers/edge/openapi/edge_worker_api_v1.yaml +++ b/providers/src/airflow/providers/edge/openapi/edge_worker_api_v1.yaml @@ -178,6 +178,186 @@ paths: summary: Register tags: - Worker + /logs/logfile_path/{dag_id}/{task_id}/{run_id}/{try_number}/{map_index}: + get: + description: Elaborate the path and filename to expect from task execution. + x-openapi-router-controller: airflow.providers.edge.worker_api.routes._v2_routes + operationId: logfile_path_v2 + parameters: + - description: Identifier of the DAG to which the task belongs. + in: path + name: dag_id + required: true + schema: + description: Identifier of the DAG to which the task belongs. + title: Dag ID + type: string + - description: Task name in the DAG. + in: path + name: task_id + required: true + schema: + description: Task name in the DAG. + title: Task ID + type: string + - description: Run ID of the DAG execution. + in: path + name: run_id + required: true + schema: + description: Run ID of the DAG execution. + title: Run ID + type: string + - description: The number of attempt to execute this task. + in: path + name: try_number + required: true + schema: + description: The number of attempt to execute this task. + title: Try Number + type: integer + - description: For dynamically mapped tasks the mapping number, -1 if the task + is not mapped. + in: path + name: map_index + required: true + schema: + description: For dynamically mapped tasks the mapping number, -1 if the + task is not mapped. + title: Map Index + type: string # This should be integer, but Connexion/Flask do not support negative integers in path parameters + - description: JWT Authorization Token + in: header + name: authorization + required: true + schema: + description: JWT Authorization Token + title: Authorization + type: string + responses: + '200': + content: + application/json: + schema: + title: Response Logfile Path + type: string + description: Successful Response + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Bad Request + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Forbidden + '422': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' + description: Validation Error + summary: Logfile Path + tags: + - Logs + /logs/push/{dag_id}/{task_id}/{run_id}/{try_number}/{map_index}: + post: + description: Push an incremental log chunk from Edge Worker to central site. + x-openapi-router-controller: airflow.providers.edge.worker_api.routes._v2_routes + operationId: push_logs_v2 + parameters: + - description: Identifier of the DAG to which the task belongs. + in: path + name: dag_id + required: true + schema: + description: Identifier of the DAG to which the task belongs. + title: Dag ID + type: string + - description: Task name in the DAG. + in: path + name: task_id + required: true + schema: + description: Task name in the DAG. + title: Task ID + type: string + - description: Run ID of the DAG execution. + in: path + name: run_id + required: true + schema: + description: Run ID of the DAG execution. + title: Run ID + type: string + - description: The number of attempt to execute this task. 
+ in: path + name: try_number + required: true + schema: + description: The number of attempt to execute this task. + title: Try Number + type: integer + - description: For dynamically mapped tasks the mapping number, -1 if the task + is not mapped. + in: path + name: map_index + required: true + schema: + description: For dynamically mapped tasks the mapping number, -1 if the + task is not mapped. + title: Map Index + type: string # This should be integer, but Connexion/Flask do not support negative integers in path parameters + - description: JWT Authorization Token + in: header + name: authorization + required: true + schema: + description: JWT Authorization Token + title: Authorization + type: string + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/PushLogsBody' + description: The worker remote has no access to log sink and with this + can send log chunks to the central site. + title: Log data chunks + required: true + responses: + '200': + content: + application/json: + schema: + title: Response Push Logs + type: object + nullable: true + description: Successful Response + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Bad Request + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Forbidden + '422': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' + description: Validation Error + summary: Push Logs + tags: + - Logs /rpcapi: post: deprecated: false @@ -284,6 +464,23 @@ components: title: Sysinfo type: object title: WorkerStateBody + PushLogsBody: + description: Incremental new log content from worker. + properties: + log_chunk_data: + description: Log chunk data as incremental log text. + title: Log Chunk Data + type: string + log_chunk_time: + description: Time of the log chunk at point of sending. + format: date-time + title: Log Chunk Time + type: string + required: + - log_chunk_time + - log_chunk_data + title: PushLogsBody + type: object HTTPExceptionResponse: description: HTTPException Model used for error response. 
properties: diff --git a/providers/src/airflow/providers/edge/provider.yaml b/providers/src/airflow/providers/edge/provider.yaml index 25dd75a2624c5..c4c289b228a58 100644 --- a/providers/src/airflow/providers/edge/provider.yaml +++ b/providers/src/airflow/providers/edge/provider.yaml @@ -27,7 +27,7 @@ source-date-epoch: 1729683247 # note that those versions are maintained by release manager - do not update them manually versions: - - 0.8.0pre0 + - 0.8.1pre0 dependencies: - apache-airflow>=2.10.0 diff --git a/providers/src/airflow/providers/edge/worker_api/app.py b/providers/src/airflow/providers/edge/worker_api/app.py index 69a43edb116bf..e90c5c4709642 100644 --- a/providers/src/airflow/providers/edge/worker_api/app.py +++ b/providers/src/airflow/providers/edge/worker_api/app.py @@ -19,6 +19,7 @@ from fastapi import FastAPI from airflow.providers.edge.worker_api.routes.health import health_router +from airflow.providers.edge.worker_api.routes.logs import logs_router from airflow.providers.edge.worker_api.routes.worker import worker_router @@ -35,5 +36,6 @@ def create_edge_worker_api_app() -> FastAPI: ) edge_worker_api_app.include_router(health_router) + edge_worker_api_app.include_router(logs_router) edge_worker_api_app.include_router(worker_router) return edge_worker_api_app diff --git a/providers/src/airflow/providers/edge/worker_api/datamodels.py b/providers/src/airflow/providers/edge/worker_api/datamodels.py index 170d8c449ffc3..f4455c7e1e278 100644 --- a/providers/src/airflow/providers/edge/worker_api/datamodels.py +++ b/providers/src/airflow/providers/edge/worker_api/datamodels.py @@ -16,6 +16,7 @@ # under the License. from __future__ import annotations +from datetime import datetime from typing import ( # noqa: UP035 - prevent pytest failing in back-compat Annotated, Any, @@ -28,6 +29,20 @@ from pydantic import BaseModel, Field from airflow.providers.edge.models.edge_worker import EdgeWorkerState # noqa: TCH001 +from airflow.providers.edge.worker_api.routes._v2_compat import Path + + +class WorkerApiDocs: + """Documentation collection for the worker API.""" + + dag_id = Path(title="Dag ID", description="Identifier of the DAG to which the task belongs.") + task_id = Path(title="Task ID", description="Task name in the DAG.") + run_id = Path(title="Run ID", description="Run ID of the DAG execution.") + try_number = Path(title="Try Number", description="The number of attempt to execute this task.") + map_index = Path( + title="Map Index", + description="For dynamically mapped tasks the mapping number, -1 if the task is not mapped.", + ) class JsonRpcRequestBase(BaseModel): @@ -86,3 +101,10 @@ class WorkerQueueUpdateBody(BaseModel): Optional[List[str]], # noqa: UP006, UP007 - prevent pytest failing in back-compat Field(description="Queues to remove from worker."), ] + + +class PushLogsBody(BaseModel): + """Incremental new log content from worker.""" + + log_chunk_time: Annotated[datetime, Field(description="Time of the log chunk at point of sending.")] + log_chunk_data: Annotated[str, Field(description="Log chunk data as incremental log text.")] diff --git a/providers/src/airflow/providers/edge/worker_api/routes/_v2_routes.py b/providers/src/airflow/providers/edge/worker_api/routes/_v2_routes.py index 6f2e81caa0026..f155c63e96a5b 100644 --- a/providers/src/airflow/providers/edge/worker_api/routes/_v2_routes.py +++ b/providers/src/airflow/providers/edge/worker_api/routes/_v2_routes.py @@ -24,10 +24,13 @@ from typing import TYPE_CHECKING, Any, Callable from uuid import uuid4 +from flask 
+
 from airflow.exceptions import AirflowException
 from airflow.providers.edge.worker_api.auth import jwt_token_authorization, jwt_token_authorization_rpc
-from airflow.providers.edge.worker_api.datamodels import JsonRpcRequest, WorkerStateBody
+from airflow.providers.edge.worker_api.datamodels import JsonRpcRequest, PushLogsBody, WorkerStateBody
 from airflow.providers.edge.worker_api.routes._v2_compat import HTTPException, status
+from airflow.providers.edge.worker_api.routes.logs import logfile_path, push_logs
 from airflow.providers.edge.worker_api.routes.worker import register, set_state
 from airflow.serialization.serialized_objects import BaseSerialization
 from airflow.utils.session import NEW_SESSION, create_session, provide_session
@@ -184,8 +187,6 @@ def rpcapi_v2(body: dict[str, Any]) -> APIResponse:
     # Note: Except the method map this _was_ a 100% copy of internal API module
     # airflow.api_internal.endpoints.rpc_api_endpoint.internal_airflow_api()
     # As of rework for FastAPI in Airflow 3.0, this is updated and to be removed in the future.
-    from flask import Response, request
-
     try:
         if request.headers.get("Content-Type", "") != "application/json":
             raise HTTPException(status.HTTP_403_FORBIDDEN, "Expected Content-Type: application/json")
@@ -238,8 +239,6 @@ def rpcapi_v2(body: dict[str, Any]) -> APIResponse:
 @provide_session
 def register_v2(worker_name: str, body: dict[str, Any], session=NEW_SESSION) -> Any:
     """Handle Edge Worker API `/edge_worker/v1/worker/{worker_name}` endpoint for Airflow 2.10."""
-    from flask import request
-
     try:
         auth = request.headers.get("Authorization", "")
         jwt_token_authorization(request.path, auth)
@@ -254,8 +253,6 @@ def register_v2(worker_name: str, body: dict[str, Any], session=NEW_SESSION) ->
 @provide_session
 def set_state_v2(worker_name: str, body: dict[str, Any], session=NEW_SESSION) -> Any:
     """Handle Edge Worker API `/edge_worker/v1/worker/{worker_name}` endpoint for Airflow 2.10."""
-    from flask import request
-
     try:
         auth = request.headers.get("Authorization", "")
         jwt_token_authorization(request.path, auth)
@@ -268,3 +265,39 @@ def set_state_v2(worker_name: str, body: dict[str, Any], session=NEW_SESSION) ->
         return set_state(worker_name, request_obj, session)
     except HTTPException as e:
         return e.to_response()  # type: ignore[attr-defined]
+
+
+def logfile_path_v2(
+    dag_id: str,
+    task_id: str,
+    run_id: str,
+    try_number: int,
+    map_index: str,  # Note: Connexion cannot handle negative numbers in path parameters, therefore a string is used
+) -> str:
+    """Handle Edge Worker API `/edge_worker/v1/logs/logfile_path/{dag_id}/{task_id}/{run_id}/{try_number}/{map_index}` endpoint for Airflow 2.10."""
+    try:
+        auth = request.headers.get("Authorization", "")
+        jwt_token_authorization(request.path, auth)
+        return logfile_path(dag_id, task_id, run_id, try_number, int(map_index))
+    except HTTPException as e:
+        return e.to_response()  # type: ignore[attr-defined]
+
+
+def push_logs_v2(
+    dag_id: str,
+    task_id: str,
+    run_id: str,
+    try_number: int,
+    map_index: str,  # Note: Connexion cannot handle negative numbers in path parameters, therefore a string is used
+    body: dict[str, Any],
+) -> None:
+    """Handle Edge Worker API `/edge_worker/v1/logs/push/{dag_id}/{task_id}/{run_id}/{try_number}/{map_index}` endpoint for Airflow 2.10."""
+    try:
+        auth = request.headers.get("Authorization", "")
+        jwt_token_authorization(request.path, auth)
+        request_obj = PushLogsBody(
+            log_chunk_data=body["log_chunk_data"], log_chunk_time=body["log_chunk_time"]
+        )
+        push_logs(dag_id, task_id, run_id, try_number, int(map_index), request_obj)
+    except HTTPException as e:
+        return e.to_response()  # type: ignore[attr-defined]
diff --git a/providers/src/airflow/providers/edge/worker_api/routes/logs.py b/providers/src/airflow/providers/edge/worker_api/routes/logs.py
new file mode 100644
index 0000000000000..1e5f9be7db677
--- /dev/null
+++ b/providers/src/airflow/providers/edge/worker_api/routes/logs.py
@@ -0,0 +1,133 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from functools import cache
+from pathlib import Path
+from typing import TYPE_CHECKING, Annotated
+
+from airflow.configuration import conf
+from airflow.models.taskinstance import TaskInstance
+from airflow.models.taskinstancekey import TaskInstanceKey
+from airflow.providers.edge.models.edge_logs import EdgeLogsModel
+from airflow.providers.edge.worker_api.auth import jwt_token_authorization_rest
+from airflow.providers.edge.worker_api.datamodels import PushLogsBody, WorkerApiDocs
+from airflow.providers.edge.worker_api.routes._v2_compat import (
+    AirflowRouter,
+    Body,
+    Depends,
+    create_openapi_http_exception_doc,
+    status,
+)
+from airflow.utils.session import NEW_SESSION, create_session, provide_session
+
+logs_router = AirflowRouter(tags=["Logs"], prefix="/logs")
+
+
+@cache
+@provide_session
+def _logfile_path(task: TaskInstanceKey, session=NEW_SESSION) -> str:
+    """Determine the (relative) path and filename to expect from task execution."""
+    from airflow.utils.log.file_task_handler import FileTaskHandler
+
+    ti = TaskInstance.get_task_instance(
+        dag_id=task.dag_id,
+        run_id=task.run_id,
+        task_id=task.task_id,
+        map_index=task.map_index,
+        session=session,
+    )
+    if TYPE_CHECKING:
+        assert ti
+    return FileTaskHandler(".")._render_filename(ti, task.try_number)
+
+
+@logs_router.get(
+    "/logfile_path/{dag_id}/{task_id}/{run_id}/{try_number}/{map_index}",
+    dependencies=[Depends(jwt_token_authorization_rest)],
+    responses=create_openapi_http_exception_doc(
+        [
+            status.HTTP_400_BAD_REQUEST,
+            status.HTTP_403_FORBIDDEN,
+        ]
+    ),
+)
+def logfile_path(
+    dag_id: Annotated[str, WorkerApiDocs.dag_id],
+    task_id: Annotated[str, WorkerApiDocs.task_id],
+    run_id: Annotated[str, WorkerApiDocs.run_id],
+    try_number: Annotated[int, WorkerApiDocs.try_number],
+    map_index: Annotated[int, WorkerApiDocs.map_index],
+) -> str:
+    """Determine the path and filename to expect from task execution."""
+    task = TaskInstanceKey(
+        dag_id=dag_id, task_id=task_id, run_id=run_id, try_number=try_number, map_index=map_index
+    )
+    return _logfile_path(task)
+
+
+@logs_router.post(
+    "/push/{dag_id}/{task_id}/{run_id}/{try_number}/{map_index}",
+    dependencies=[Depends(jwt_token_authorization_rest)],
+    responses=create_openapi_http_exception_doc(
+        [
+            status.HTTP_400_BAD_REQUEST,
+            status.HTTP_403_FORBIDDEN,
+        ]
+    ),
+)
+def push_logs(
+    dag_id: Annotated[str, WorkerApiDocs.dag_id],
+    task_id: Annotated[str, WorkerApiDocs.task_id],
+    run_id: Annotated[str, WorkerApiDocs.run_id],
+    try_number: Annotated[int, WorkerApiDocs.try_number],
+    map_index: Annotated[int, WorkerApiDocs.map_index],
+    body: Annotated[
+        PushLogsBody,
+        Body(
+            title="Log data chunks",
+            description="The remote worker has no access to the log sink, so it sends log chunks to the central site instead.",
+        ),
+    ],
+) -> None:
+    """Push an incremental log chunk from Edge Worker to central site."""
+    with create_session() as session:
+        log_chunk = EdgeLogsModel(
+            dag_id=dag_id,
+            task_id=task_id,
+            run_id=run_id,
+            map_index=map_index,
+            try_number=try_number,
+            log_chunk_time=body.log_chunk_time,
+            log_chunk_data=body.log_chunk_data,
+        )
+        session.add(log_chunk)
+        session.commit()
+    # Write logs to local file to make them accessible
+    task = TaskInstanceKey(
+        dag_id=dag_id, task_id=task_id, run_id=run_id, try_number=try_number, map_index=map_index
+    )
+    base_log_folder = conf.get("logging", "base_log_folder", fallback="NOT AVAILABLE")
+    logfile_path = Path(base_log_folder, _logfile_path(task))
+    if not logfile_path.exists():
+        new_folder_permissions = int(
+            conf.get("logging", "file_task_handler_new_folder_permissions", fallback="0o775"), 8
+        )
+        logfile_path.parent.mkdir(parents=True, exist_ok=True, mode=new_folder_permissions)
+    with logfile_path.open("a") as logfile:
+        logfile.write(body.log_chunk_data)
diff --git a/providers/tests/edge/cli/test_edge_command.py b/providers/tests/edge/cli/test_edge_command.py
index 3304831064abd..4f2706ef53114 100644
--- a/providers/tests/edge/cli/test_edge_command.py
+++ b/providers/tests/edge/cli/test_edge_command.py
@@ -30,6 +30,7 @@
 from airflow.providers.edge.cli.edge_command import _EdgeWorkerCli, _Job, _write_pid_to_pidfile
 from airflow.providers.edge.models.edge_job import EdgeJob
 from airflow.providers.edge.models.edge_worker import EdgeWorkerState, EdgeWorkerVersionException
+from airflow.utils import timezone
 from airflow.utils.state import TaskInstanceState
 
 from tests_common.test_utils.config import conf_vars
@@ -146,7 +147,7 @@ def worker_with_job(self, tmp_path: Path, dummy_joblist: list[_Job]) -> _EdgeWor
         ],
     )
     @patch("airflow.providers.edge.models.edge_job.EdgeJob.reserve_task")
-    @patch("airflow.providers.edge.models.edge_logs.EdgeLogs.logfile_path")
+    @patch("airflow.providers.edge.cli.edge_command.logs_logfile_path")
     @patch("airflow.providers.edge.models.edge_job.EdgeJob.set_state")
     @patch("subprocess.Popen")
     def test_fetch_job(
@@ -201,8 +202,8 @@ def test_check_running_jobs_failed(self, mock_set_state, worker_with_job: _EdgeW
         assert worker_with_job.free_concurrency == worker_with_job.concurrency
 
     @time_machine.travel(datetime.now(), tick=False)
-    @patch("airflow.providers.edge.models.edge_logs.EdgeLogs.push_logs")
-    def test_check_running_jobs_log_push(self, mock_push_logs, worker_with_job: _EdgeWorkerCli):
+    @patch("airflow.providers.edge.cli.edge_command.logs_push")
+    def test_check_running_jobs_log_push(self, mock_logs_push, worker_with_job: _EdgeWorkerCli):
         job = worker_with_job.jobs[0]
         job.logfile.write_text("some log content")
         with conf_vars(
@@ -213,13 +214,13 @@ def test_check_running_jobs_log_push(self, mock_push_logs, worker_with_job: _Edg
         ):
             worker_with_job.check_running_jobs()
         assert len(worker_with_job.jobs) == 1
-        mock_push_logs.assert_called_once_with(
-            task=job.edge_job.key, log_chunk_time=datetime.now(), log_chunk_data="some log content"
+        mock_logs_push.assert_called_once_with(
+            task=job.edge_job.key, log_chunk_time=timezone.utcnow(), log_chunk_data="some log content"
         )
 
     @time_machine.travel(datetime.now(), tick=False)
-    @patch("airflow.providers.edge.models.edge_logs.EdgeLogs.push_logs")
-    def test_check_running_jobs_log_push_increment(self, mock_push_logs, worker_with_job: _EdgeWorkerCli):
+    @patch("airflow.providers.edge.cli.edge_command.logs_push")
+    def test_check_running_jobs_log_push_increment(self, mock_logs_push, worker_with_job: _EdgeWorkerCli):
         job = worker_with_job.jobs[0]
         job.logfile.write_text("hello ")
         job.logsize = job.logfile.stat().st_size
@@ -232,13 +233,13 @@ def test_check_running_jobs_log_push_increment(self, mock_push_logs, worker_with
         ):
             worker_with_job.check_running_jobs()
         assert len(worker_with_job.jobs) == 1
-        mock_push_logs.assert_called_once_with(
-            task=job.edge_job.key, log_chunk_time=datetime.now(), log_chunk_data="world"
+        mock_logs_push.assert_called_once_with(
+            task=job.edge_job.key, log_chunk_time=timezone.utcnow(), log_chunk_data="world"
         )
 
     @time_machine.travel(datetime.now(), tick=False)
-    @patch("airflow.providers.edge.models.edge_logs.EdgeLogs.push_logs")
-    def test_check_running_jobs_log_push_chunks(self, mock_push_logs, worker_with_job: _EdgeWorkerCli):
+    @patch("airflow.providers.edge.cli.edge_command.logs_push")
+    def test_check_running_jobs_log_push_chunks(self, mock_logs_push, worker_with_job: _EdgeWorkerCli):
         job = worker_with_job.jobs[0]
         job.logfile.write_bytes("log1log2ülog3".encode("latin-1"))
         with conf_vars(
@@ -246,12 +247,20 @@ def test_check_running_jobs_log_push_chunks(self, mock_push_logs, worker_with_jo
         ):
             worker_with_job.check_running_jobs()
         assert len(worker_with_job.jobs) == 1
-        calls = mock_push_logs.call_args_list
+        calls = mock_logs_push.call_args_list
         assert len(calls) == 4
-        assert calls[0] == call(task=job.edge_job.key, log_chunk_time=datetime.now(), log_chunk_data="log1")
-        assert calls[1] == call(task=job.edge_job.key, log_chunk_time=datetime.now(), log_chunk_data="log2")
-        assert calls[2] == call(task=job.edge_job.key, log_chunk_time=datetime.now(), log_chunk_data="\\xfc")
-        assert calls[3] == call(task=job.edge_job.key, log_chunk_time=datetime.now(), log_chunk_data="log3")
+        assert calls[0] == call(
+            task=job.edge_job.key, log_chunk_time=timezone.utcnow(), log_chunk_data="log1"
+        )
+        assert calls[1] == call(
+            task=job.edge_job.key, log_chunk_time=timezone.utcnow(), log_chunk_data="log2"
+        )
+        assert calls[2] == call(
+            task=job.edge_job.key, log_chunk_time=timezone.utcnow(), log_chunk_data="\\xfc"
+        )
+        assert calls[3] == call(
+            task=job.edge_job.key, log_chunk_time=timezone.utcnow(), log_chunk_data="log3"
+        )
 
     @pytest.mark.parametrize(
         "drain, jobs, expected_state",
diff --git a/providers/tests/edge/models/test_edge_logs.py b/providers/tests/edge/models/test_edge_logs.py
deleted file mode 100644
index 0bb3307e0ca46..0000000000000
--- a/providers/tests/edge/models/test_edge_logs.py
+++ /dev/null
@@ -1,49 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from __future__ import annotations
-
-import pytest
-
-from airflow.providers.edge.models.edge_logs import EdgeLogs, EdgeLogsModel
-from airflow.utils import timezone
-
-pytestmark = pytest.mark.db_test
-
-pytest.importorskip("pydantic", minversion="2.0.0")
-
-
-def test_serializing_pydantic_edge_logs():
-    rlm = EdgeLogsModel(
-        dag_id="test_dag",
-        task_id="test_task",
-        run_id="test_run",
-        map_index=-1,
-        try_number=1,
-        log_chunk_time=timezone.utcnow(),
-        log_chunk_data="some logs captured",
-    )
-
-    pydantic_logs = EdgeLogs.model_validate(rlm)
-
-    json_string = pydantic_logs.model_dump_json()
-    print(json_string)
-
-    deserialized_model = EdgeLogs.model_validate_json(json_string)
-    assert deserialized_model.dag_id == rlm.dag_id
-    assert deserialized_model.try_number == rlm.try_number
-    assert deserialized_model.log_chunk_time == rlm.log_chunk_time
-    assert deserialized_model.log_chunk_data == rlm.log_chunk_data
diff --git a/providers/tests/edge/worker_api/routes/test_logs.py b/providers/tests/edge/worker_api/routes/test_logs.py
new file mode 100644
index 0000000000000..84090192b38c7
--- /dev/null
+++ b/providers/tests/edge/worker_api/routes/test_logs.py
@@ -0,0 +1,75 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import pytest
+
+from airflow.operators.empty import EmptyOperator
+from airflow.providers.edge.models.edge_logs import EdgeLogsModel
+from airflow.providers.edge.worker_api.datamodels import PushLogsBody
+from airflow.providers.edge.worker_api.routes.logs import logfile_path, push_logs
+from airflow.utils import timezone
+
+if TYPE_CHECKING:
+    from sqlalchemy.orm import Session
+
+pytestmark = pytest.mark.db_test
+
+
+DAG_ID = "my_dag"
+TASK_ID = "my_task"
+RUN_ID = "manual__2024-11-24T21:03:01+01:00"
+
+
+class TestLogsApiRoutes:
+    @pytest.fixture(autouse=True)
+    def setup_test_cases(self, dag_maker, session: Session):
+        with dag_maker(DAG_ID):
+            EmptyOperator(task_id=TASK_ID)
+        dag_maker.create_dagrun(run_id=RUN_ID)
+
+        session.query(EdgeLogsModel).delete()
+        session.commit()
+
+    def test_logfile_path(self, session: Session):
+        p: str = logfile_path(dag_id=DAG_ID, task_id=TASK_ID, run_id=RUN_ID, try_number=1, map_index=-1)
+        assert p
+        assert f"dag_id={DAG_ID}/run_id={RUN_ID}/task_id={TASK_ID}/attempt=1" in p
+        assert "/-1" not in p
+
+    def test_push_logs(self, session: Session):
+        log_data = PushLogsBody(
+            log_chunk_data="This is Lorem Ipsum log data", log_chunk_time=timezone.utcnow()
+        )
+        push_logs(
+            dag_id=DAG_ID,
+            task_id=TASK_ID,
+            run_id=RUN_ID,
+            try_number=1,
+            map_index=-1,
+            body=log_data,
+        )
+        logs: list[EdgeLogsModel] = session.query(EdgeLogsModel).all()
+        assert len(logs) == 1
+        assert logs[0].dag_id == DAG_ID
+        assert logs[0].task_id == TASK_ID
+        assert logs[0].run_id == RUN_ID
+        assert logs[0].try_number == 1
+        assert logs[0].map_index == -1
+        assert "Lorem Ipsum" in logs[0].log_chunk_data

From 9a9348951594429d11ef407e9354e99bea938d96 Mon Sep 17 00:00:00 2001
From: Jens Scheffler
Date: Sat, 30 Nov 2024 18:13:40 +0100
Subject: [PATCH 2/3] Review feedback, use SessionDep from FastAPI

---
 .../edge/worker_api/routes/_v2_routes.py      |  3 ++-
 .../providers/edge/worker_api/routes/logs.py  | 26 +++++++++----------
 .../tests/edge/worker_api/routes/test_logs.py | 19 ++++++++------
 3 files changed, 26 insertions(+), 22 deletions(-)

diff --git a/providers/src/airflow/providers/edge/worker_api/routes/_v2_routes.py b/providers/src/airflow/providers/edge/worker_api/routes/_v2_routes.py
index f155c63e96a5b..128500f63c621 100644
--- a/providers/src/airflow/providers/edge/worker_api/routes/_v2_routes.py
+++ b/providers/src/airflow/providers/edge/worker_api/routes/_v2_routes.py
@@ -298,6 +298,7 @@ def push_logs_v2(
         request_obj = PushLogsBody(
             log_chunk_data=body["log_chunk_data"], log_chunk_time=body["log_chunk_time"]
         )
-        push_logs(dag_id, task_id, run_id, try_number, int(map_index), request_obj)
+        with create_session() as session:
+            push_logs(dag_id, task_id, run_id, try_number, int(map_index), request_obj, session)
     except HTTPException as e:
         return e.to_response()  # type: ignore[attr-defined]
diff --git a/providers/src/airflow/providers/edge/worker_api/routes/logs.py b/providers/src/airflow/providers/edge/worker_api/routes/logs.py
index 1e5f9be7db677..3dc04a6670959 100644
--- a/providers/src/airflow/providers/edge/worker_api/routes/logs.py
+++ b/providers/src/airflow/providers/edge/worker_api/routes/logs.py
@@ -31,10 +31,11 @@
     AirflowRouter,
     Body,
     Depends,
+    SessionDep,
     create_openapi_http_exception_doc,
     status,
 )
-from airflow.utils.session import NEW_SESSION, create_session, provide_session
+from airflow.utils.session import NEW_SESSION, provide_session
 
 logs_router = AirflowRouter(tags=["Logs"], prefix="/logs")
 
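For readers unfamiliar with the `SessionDep` pattern this commit adopts: FastAPI injects the database session into the handler instead of the handler opening one via `create_session()`. A minimal, self-contained sketch of how such a dependency is typically wired (the real `SessionDep` comes from Airflow's `_v2_compat` shim; the in-memory engine and the `get_session` helper below are illustrative assumptions, not Airflow code):

```python
from typing import Annotated

from fastapi import Depends, FastAPI
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker

engine = create_engine("sqlite://")  # illustrative in-memory engine, not Airflow's
SessionLocal = sessionmaker(bind=engine)


def get_session():
    """Yield one session per request; commit on success, roll back on error."""
    session = SessionLocal()
    try:
        yield session
        session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()


# A parameter annotated like this receives the yielded session automatically.
SessionDep = Annotated[Session, Depends(get_session)]

app = FastAPI()


@app.post("/example")
def example(session: SessionDep) -> None:
    # Work with the injected session; commit/rollback is handled by get_session.
    ...
```

The payoff is visible in the hunk below: `push_logs` just declares `session: SessionDep` and drops its `create_session()` block and explicit commit.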
@@ -104,20 +105,19 @@ def push_logs(
             description="The remote worker has no access to the log sink, so it sends log chunks to the central site instead.",
         ),
     ],
+    session: SessionDep,
 ) -> None:
     """Push an incremental log chunk from Edge Worker to central site."""
-    with create_session() as session:
-        log_chunk = EdgeLogsModel(
-            dag_id=dag_id,
-            task_id=task_id,
-            run_id=run_id,
-            map_index=map_index,
-            try_number=try_number,
-            log_chunk_time=body.log_chunk_time,
-            log_chunk_data=body.log_chunk_data,
-        )
-        session.add(log_chunk)
-        session.commit()
+    log_chunk = EdgeLogsModel(
+        dag_id=dag_id,
+        task_id=task_id,
+        run_id=run_id,
+        map_index=map_index,
+        try_number=try_number,
+        log_chunk_time=body.log_chunk_time,
+        log_chunk_data=body.log_chunk_data,
+    )
+    session.add(log_chunk)
     # Write logs to local file to make them accessible
     task = TaskInstanceKey(
         dag_id=dag_id, task_id=task_id, run_id=run_id, try_number=try_number, map_index=map_index
diff --git a/providers/tests/edge/worker_api/routes/test_logs.py b/providers/tests/edge/worker_api/routes/test_logs.py
index 84090192b38c7..75380172cb436 100644
--- a/providers/tests/edge/worker_api/routes/test_logs.py
+++ b/providers/tests/edge/worker_api/routes/test_logs.py
@@ -25,6 +25,7 @@
 from airflow.providers.edge.worker_api.datamodels import PushLogsBody
 from airflow.providers.edge.worker_api.routes.logs import logfile_path, push_logs
 from airflow.utils import timezone
+from airflow.utils.session import create_session
 
 if TYPE_CHECKING:
     from sqlalchemy.orm import Session
@@ -57,14 +58,16 @@ def test_push_logs(self, session: Session):
         log_data = PushLogsBody(
             log_chunk_data="This is Lorem Ipsum log data", log_chunk_time=timezone.utcnow()
         )
-        push_logs(
-            dag_id=DAG_ID,
-            task_id=TASK_ID,
-            run_id=RUN_ID,
-            try_number=1,
-            map_index=-1,
-            body=log_data,
-        )
+        with create_session(session) as session:
+            push_logs(
+                dag_id=DAG_ID,
+                task_id=TASK_ID,
+                run_id=RUN_ID,
+                try_number=1,
+                map_index=-1,
+                body=log_data,
+                session=session,
+            )
         logs: list[EdgeLogsModel] = session.query(EdgeLogsModel).all()
         assert len(logs) == 1
         assert logs[0].dag_id == DAG_ID

From 2311429612bf9f9f85c0658dc6e1a22185e4bd67 Mon Sep 17 00:00:00 2001
From: Jens Scheffler
Date: Sat, 30 Nov 2024 18:36:27 +0100
Subject: [PATCH 3/3] Fix pytest

---
 providers/tests/edge/worker_api/routes/test_logs.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/providers/tests/edge/worker_api/routes/test_logs.py b/providers/tests/edge/worker_api/routes/test_logs.py
index 75380172cb436..45da5fb02bca1 100644
--- a/providers/tests/edge/worker_api/routes/test_logs.py
+++ b/providers/tests/edge/worker_api/routes/test_logs.py
@@ -58,7 +58,7 @@ def test_push_logs(self, session: Session):
         log_data = PushLogsBody(
             log_chunk_data="This is Lorem Ipsum log data", log_chunk_time=timezone.utcnow()
         )
-        with create_session(session) as session:
+        with create_session() as session:
             push_logs(
                 dag_id=DAG_ID,
                 task_id=TASK_ID,
                 run_id=RUN_ID,
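End-to-end, the series leaves the worker talking to two new REST routes instead of the deprecated internal RPC API. A sketch of the resulting worker-side flow using the client helpers added in patch 1 (the `TaskInstanceKey` values are placeholders, and the single `read_text()` stands in for the incremental chunked reads that `check_running_jobs` actually performs):

```python
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.providers.edge.cli.api_client import logs_logfile_path, logs_push
from airflow.utils import timezone

# Placeholder key; real values come from the EdgeJob fetched by the worker.
key = TaskInstanceKey(
    dag_id="my_dag",
    task_id="my_task",
    run_id="manual__2024-11-24T21:03:01+01:00",
    try_number=1,
    map_index=-1,
)

# GET logs/logfile_path/... tells the worker where the task writes its log.
logfile = logs_logfile_path(key)

# POST logs/push/... ships new log content to the central site,
# stamped with a timezone-aware timestamp as in edge_command.py.
if logfile.exists():
    chunk = logfile.read_text()
    if chunk:
        logs_push(task=key, log_chunk_time=timezone.utcnow(), log_chunk_data=chunk)
```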