diff --git a/airflow/providers/elasticsearch/log/es_response.py b/airflow/providers/elasticsearch/log/es_response.py
index ce11c715aacb5..6d8a3aeac366c 100644
--- a/airflow/providers/elasticsearch/log/es_response.py
+++ b/airflow/providers/elasticsearch/log/es_response.py
@@ -16,6 +16,8 @@
 # under the License.
 from __future__ import annotations
 
+from typing import Iterator
+
 
 def _wrap(val):
     if isinstance(val, dict):
@@ -117,7 +119,7 @@ def __init__(self, search, response, doc_class=None):
         super().__setattr__("_doc_class", doc_class)
         super().__init__(response)
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[Hit]:
         return iter(self.hits)
 
     def __getitem__(self, key):
@@ -129,7 +131,7 @@ def __bool__(self):
         return bool(self.hits)
 
     @property
-    def hits(self):
+    def hits(self) -> list[Hit]:
        """
        This property provides access to the hits (i.e., the results) of the Elasticsearch response.
 
diff --git a/airflow/providers/elasticsearch/log/es_task_handler.py b/airflow/providers/elasticsearch/log/es_task_handler.py
index 6c97e25f3b691..33a323a95895b 100644
--- a/airflow/providers/elasticsearch/log/es_task_handler.py
+++ b/airflow/providers/elasticsearch/log/es_task_handler.py
@@ -25,14 +25,13 @@
 import warnings
 from collections import defaultdict
 from operator import attrgetter
-from typing import TYPE_CHECKING, Any, Callable, List, Tuple
+from typing import TYPE_CHECKING, Any, Callable, List, Literal, Tuple
 from urllib.parse import quote, urlparse
 
 # Using `from elasticsearch import *` would break elasticsearch mocking used in unit test.
 import elasticsearch
 import pendulum
 from elasticsearch.exceptions import NotFoundError
-from typing_extensions import Literal
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowProviderDeprecationWarning
@@ -136,7 +135,7 @@ def __init__(
         super().__init__(base_log_folder, filename_template)
         self.closed = False
 
-        self.client = elasticsearch.Elasticsearch(host, **es_kwargs)  # type: ignore[attr-defined]
+        self.client = elasticsearch.Elasticsearch(host, **es_kwargs)
         # in airflow.cfg, host of elasticsearch has to be http://dockerhostXxxx:9200
         if USE_PER_RUN_LOG_ID and log_id_template is not None:
             warnings.warn(
@@ -237,12 +236,11 @@ def _clean_date(value: datetime | None) -> str:
             return ""
         return value.strftime("%Y_%m_%dT%H_%M_%S_%f")
 
-    def _group_logs_by_host(self, logs):
+    def _group_logs_by_host(self, response: ElasticSearchResponse) -> dict[str, list[Hit]]:
         grouped_logs = defaultdict(list)
-        for log in logs:
-            key = getattr_nested(log, self.host_field, None) or "default_host"
-            grouped_logs[key].append(log)
-
+        for hit in response:
+            key = getattr_nested(hit, self.host_field, None) or "default_host"
+            grouped_logs[key].append(hit)
         return grouped_logs
 
     def _read_grouped_logs(self):
@@ -267,9 +265,14 @@ def _read(
         offset = metadata["offset"]
         log_id = self._render_log_id(ti, try_number)
 
-        logs = self._es_read(log_id, offset)
-        logs_by_host = self._group_logs_by_host(logs)
-        next_offset = offset if not logs else attrgetter(self.offset_field)(logs[-1])
+        response = self._es_read(log_id, offset)
+        if response is not None and response.hits:
+            logs_by_host = self._group_logs_by_host(response)
+            next_offset = attrgetter(self.offset_field)(response[-1])
+        else:
+            logs_by_host = None
+            next_offset = offset
+
         # Ensure a string here. Large offset numbers will get JSON.parsed incorrectly
         # on the client. Sending as a string prevents this issue.
         # https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Number/MAX_SAFE_INTEGER
@@ -278,8 +281,9 @@ def _read(
         # end_of_log_mark may contain characters like '\n' which is needed to
         # have the log uploaded but will not be stored in elasticsearch.
         metadata["end_of_log"] = False
-        if any(x[-1].message == self.end_of_log_mark for x in logs_by_host.values()):
-            metadata["end_of_log"] = True
+        if logs_by_host:
+            if any(x[-1].message == self.end_of_log_mark for x in logs_by_host.values()):
+                metadata["end_of_log"] = True
 
         cur_ts = pendulum.now()
         if "last_log_timestamp" in metadata:
@@ -308,27 +312,30 @@ def _read(
 
         # If we hit the end of the log, remove the actual end_of_log message
         # to prevent it from showing in the UI.
-        def concat_logs(lines):
-            log_range = (len(lines) - 1) if lines[-1].message == self.end_of_log_mark else len(lines)
-            return "\n".join(self._format_msg(lines[i]) for i in range(log_range))
+        def concat_logs(hits: list[Hit]):
+            log_range = (len(hits) - 1) if hits[-1].message == self.end_of_log_mark else len(hits)
+            return "\n".join(self._format_msg(hits[i]) for i in range(log_range))
 
-        message = [(host, concat_logs(hosted_log)) for host, hosted_log in logs_by_host.items()]
+        if logs_by_host:
+            message = [(host, concat_logs(hits)) for host, hits in logs_by_host.items()]
+        else:
+            message = []
 
         return message, metadata
 
-    def _format_msg(self, log_line):
+    def _format_msg(self, hit: Hit):
         """Format ES Record to match settings.LOG_FORMAT when used with json_format."""
         # Using formatter._style.format makes it future proof i.e.
         # if we change the formatter style from '%' to '{' or '$', this will still work
         if self.json_format:
             with contextlib.suppress(Exception):
                 return self.formatter._style.format(
-                    logging.makeLogRecord({**LOG_LINE_DEFAULTS, **log_line.to_dict()})
+                    logging.makeLogRecord({**LOG_LINE_DEFAULTS, **hit.to_dict()})
                 )
         # Just a safe-guard to preserve backwards-compatibility
-        return log_line.message
+        return hit.message
 
-    def _es_read(self, log_id: str, offset: int | str) -> list | ElasticSearchResponse:
+    def _es_read(self, log_id: str, offset: int | str) -> ElasticSearchResponse | None:
         """
         Return the logs matching log_id in Elasticsearch and next offset or ''.
 
@@ -352,7 +359,6 @@ def _es_read(self, log_id: str, offset: int | str) -> list | ElasticSearchRespon
                 self.log.exception("The target index pattern %s does not exist", self.index_patterns)
                 raise e
 
-        logs: list[Any] | ElasticSearchResponse = []
         if max_log_line != 0:
             try:
                 query.update({"sort": [self.offset_field]})
@@ -362,11 +368,11 @@ def _es_read(self, log_id: str, offset: int | str) -> list | ElasticSearchRespon
                     size=self.MAX_LINE_PER_PAGE,
                     from_=self.MAX_LINE_PER_PAGE * self.PAGE,
                 )
-                logs = ElasticSearchResponse(self, res)
+                return ElasticSearchResponse(self, res)
             except Exception as err:
                 self.log.exception("Could not read log with log_id: %s. Exception: %s", log_id, err)
 
-        return logs
+        return None
 
     def emit(self, record):
         if self.handler:
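Note on the new control flow (illustration only, not part of the patch): _es_read now returns ElasticSearchResponse | None instead of a possibly-empty list, so _read branches on the optional result and only then groups hits by host and advances the offset. Below is a minimal runnable sketch of that pattern; FakeHit, fetch_page, and read are hypothetical stand-ins for the provider's Hit, _es_read, and _read.

from __future__ import annotations

from collections import defaultdict
from dataclasses import dataclass


@dataclass
class FakeHit:
    # Hypothetical stand-in for the provider's Hit: just the fields used here.
    host: str
    offset: int
    message: str


def fetch_page(log_id: str, offset: int) -> list[FakeHit] | None:
    # Stand-in for _es_read: None signals "search failed or returned nothing",
    # mirroring the patched return type ElasticSearchResponse | None.
    hits = [
        FakeHit("worker-1", 1, "started"),
        FakeHit("worker-2", 2, "running"),
    ]
    page = [h for h in hits if h.offset > offset]
    return page or None


def read(log_id: str, offset: int) -> tuple[dict[str, list[FakeHit]] | None, int]:
    response = fetch_page(log_id, offset)
    if response is not None:
        # Group hits per host, falling back to a default key, as in
        # the patched _group_logs_by_host.
        grouped: dict[str, list[FakeHit]] = defaultdict(list)
        for hit in response:
            grouped[hit.host or "default_host"].append(hit)
        # The next offset comes from the last hit, as in the patched _read.
        return grouped, response[-1].offset
    # Nothing new: keep the caller's offset so polling can resume from it.
    return None, offset


print(read("dag-task-run", 0))  # grouped hits per host, next offset 2
print(read("dag-task-run", 2))  # (None, 2): no new hits past offset 2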