diff --git a/airflow/providers/apache/livy/hooks/livy.py b/airflow/providers/apache/livy/hooks/livy.py index a3da405e788a5..012d63b0cf53a 100644 --- a/airflow/providers/apache/livy/hooks/livy.py +++ b/airflow/providers/apache/livy/hooks/livy.py @@ -17,15 +17,20 @@ """This module contains the Apache Livy hook.""" from __future__ import annotations +import asyncio import json import re from enum import Enum from typing import Any, Sequence +import aiohttp import requests +from aiohttp import ClientResponseError +from asgiref.sync import sync_to_async from airflow.exceptions import AirflowException -from airflow.providers.http.hooks.http import HttpHook +from airflow.models import Connection +from airflow.providers.http.hooks.http import HttpAsyncHook, HttpHook from airflow.utils.log.logging_mixin import LoggingMixin @@ -444,3 +449,386 @@ def _validate_extra_conf(conf: dict[Any, Any]) -> bool: if any(True for k, v in conf.items() if not (v and isinstance(v, str) or isinstance(v, int))): raise ValueError("'conf' values must be either strings or ints") return True + + +class LivyAsyncHook(HttpAsyncHook, LoggingMixin): + """ + Hook for Apache Livy through the REST API asynchronously + + :param livy_conn_id: reference to a pre-defined Livy Connection. + :param extra_options: A dictionary of options passed to Livy. + :param extra_headers: A dictionary of headers passed to the HTTP request to livy. + + .. seealso:: + For more details refer to the Apache Livy API reference: + https://livy.apache.org/docs/latest/rest-api.html + """ + + TERMINAL_STATES = { + BatchState.SUCCESS, + BatchState.DEAD, + BatchState.KILLED, + BatchState.ERROR, + } + + _def_headers = {"Content-Type": "application/json", "Accept": "application/json"} + + conn_name_attr = "livy_conn_id" + default_conn_name = "livy_default" + conn_type = "livy" + hook_name = "Apache Livy" + + def __init__( + self, + livy_conn_id: str = default_conn_name, + extra_options: dict[str, Any] | None = None, + extra_headers: dict[str, Any] | None = None, + ) -> None: + super().__init__(http_conn_id=livy_conn_id) + self.extra_headers = extra_headers or {} + self.extra_options = extra_options or {} + + async def _do_api_call_async( + self, + endpoint: str | None = None, + data: dict[str, Any] | str | None = None, + headers: dict[str, Any] | None = None, + extra_options: dict[str, Any] | None = None, + ) -> Any: + """ + Performs an asynchronous HTTP request call + + :param endpoint: the endpoint to be called i.e. resource/v1/query? + :param data: payload to be uploaded or request parameters + :param headers: additional headers to be passed through as a dictionary + :param extra_options: Additional kwargs to pass when creating a request. 
+            For example, ``run(json=obj)`` is passed as ``aiohttp.ClientSession().get(json=obj)``
+        """
+        extra_options = extra_options or {}
+
+        # headers may be passed through directly or in the "extra" field in the connection
+        # definition
+        _headers = {}
+        auth = None
+
+        if self.http_conn_id:
+            conn = await sync_to_async(self.get_connection)(self.http_conn_id)
+
+            self.base_url = self._generate_base_url(conn)
+            if conn.login:
+                auth = self.auth_type(conn.login, conn.password)
+            if conn.extra:
+                try:
+                    _headers.update(conn.extra_dejson)
+                except TypeError:
+                    self.log.warning("Connection to %s has invalid extra field.", conn.host)
+        if headers:
+            _headers.update(headers)
+
+        if self.base_url and not self.base_url.endswith("/") and endpoint and not endpoint.startswith("/"):
+            url = self.base_url + "/" + endpoint
+        else:
+            url = (self.base_url or "") + (endpoint or "")
+
+        async with aiohttp.ClientSession() as session:
+            if self.method == "GET":
+                request_func = session.get
+            elif self.method == "POST":
+                request_func = session.post
+            elif self.method == "PATCH":
+                request_func = session.patch
+            else:
+                return {"Response": f"Unexpected HTTP Method: {self.method}", "status": "error"}
+
+            attempt_num = 1
+            while True:
+                response = await request_func(
+                    url,
+                    json=data if self.method in ("POST", "PATCH") else None,
+                    params=data if self.method == "GET" else None,
+                    headers=_headers,
+                    auth=auth,
+                    **extra_options,
+                )
+                try:
+                    response.raise_for_status()
+                    return await response.json()
+                except ClientResponseError as e:
+                    self.log.warning(
+                        "[Try %d of %d] Request to %s failed.",
+                        attempt_num,
+                        self.retry_limit,
+                        url,
+                    )
+                    if not self._retryable_error_async(e) or attempt_num == self.retry_limit:
+                        self.log.exception("HTTP error, status code: %s", e.status)
+                        # In this case, the user probably made a mistake.
+                        # Don't retry.
+                        return {"Response": e.message, "Status Code": e.status, "status": "error"}
+
+                    attempt_num += 1
+                    await asyncio.sleep(self.retry_delay)
+
+    def _generate_base_url(self, conn: Connection) -> str:
+        if conn.host and "://" in conn.host:
+            base_url: str = conn.host
+        else:
+            # schema defaults to HTTP
+            schema = conn.schema if conn.schema else "http"
+            host = conn.host if conn.host else ""
+            base_url = f"{schema}://{host}"
+        if conn.port:
+            base_url = f"{base_url}:{conn.port}"
+        return base_url
+
+    async def run_method(
+        self,
+        endpoint: str,
+        method: str = "GET",
+        data: Any | None = None,
+        headers: dict[str, Any] | None = None,
+    ) -> Any:
+        """
+        Wrapper for HttpAsyncHook; allows changing the HTTP method per call on the same hook.
+
+        :param method: HTTP method
+        :param endpoint: endpoint to call
+        :param data: request payload
+        :param headers: additional headers
+        :return: http response
+        """
+        if method not in ("GET", "POST", "PUT", "DELETE", "HEAD"):
+            return {"status": "error", "response": f"Invalid http method {method}"}
+
+        back_method = self.method
+        self.method = method
+        try:
+            result = await self._do_api_call_async(endpoint, data, headers, self.extra_options)
+        finally:
+            self.method = back_method
+        return {"status": "success", "response": result}
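A minimal usage sketch for the ``run_method`` wrapper above; the batch id is illustrative and a reachable ``livy_default`` connection is assumed:

```python
import asyncio

from airflow.providers.apache.livy.hooks.livy import LivyAsyncHook


async def fetch_batch(batch_id: int) -> dict:
    # run_method temporarily swaps the hook's HTTP method, delegates to
    # _do_api_call_async, and wraps the result in a {"status", "response"} envelope.
    hook = LivyAsyncHook(livy_conn_id="livy_default")
    return await hook.run_method(endpoint=f"/batches/{batch_id}", method="GET")


if __name__ == "__main__":
    print(asyncio.run(fetch_batch(1)))
```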
+
+    async def get_batch_state(self, session_id: int | str) -> Any:
+        """
+        Fetch the state of the specified batch asynchronously.
+
+        :param session_id: identifier of the batch session
+        :return: batch state
+        """
+        self._validate_session_id(session_id)
+        self.log.info("Fetching info for batch session %s", session_id)
+        result = await self.run_method(endpoint=f"/batches/{session_id}/state")
+        if result["status"] == "error":
+            self.log.info(result)
+            return {"batch_state": "error", "response": result, "status": "error"}
+
+        if "state" not in result["response"]:
+            self.log.info("Unable to get state for batch with id: %s", session_id)
+            return {
+                "batch_state": "error",
+                "response": f"Unable to get state for batch with id: {session_id}",
+                "status": "error",
+            }
+
+        self.log.info("Successfully fetched the batch state.")
+        return {
+            "batch_state": BatchState(result["response"]["state"]),
+            "response": "successfully fetched the batch state.",
+            "status": "success",
+        }
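The envelope returned above is easiest to read with a short consumer sketch (illustrative batch id, ``livy_default`` connection assumed):

```python
import asyncio

from airflow.providers.apache.livy.hooks.livy import LivyAsyncHook


async def is_terminal(batch_id: int) -> bool:
    hook = LivyAsyncHook(livy_conn_id="livy_default")
    result = await hook.get_batch_state(batch_id)
    # On success, "batch_state" is a BatchState enum member; on error it is the string "error".
    return result["status"] == "success" and result["batch_state"] in LivyAsyncHook.TERMINAL_STATES


print(asyncio.run(is_terminal(1)))
```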
+
+    async def get_batch_logs(
+        self, session_id: int | str, log_start_position: int, log_batch_size: int
+    ) -> Any:
+        """
+        Get the session logs for a specified batch asynchronously.
+
+        :param session_id: identifier of the batch session
+        :param log_start_position: position from where to pull the logs
+        :param log_batch_size: number of lines to pull in one batch
+        :return: response body
+        """
+        self._validate_session_id(session_id)
+        log_params = {"from": log_start_position, "size": log_batch_size}
+        result = await self.run_method(endpoint=f"/batches/{session_id}/log", data=log_params)
+        if result["status"] == "error":
+            self.log.info(result)
+            return {"response": result["response"], "status": "error"}
+        return {"response": result["response"], "status": "success"}
+
+    async def dump_batch_logs(self, session_id: int | str) -> Any:
+        """
+        Dump the session logs for a specified batch asynchronously.
+
+        :param session_id: identifier of the batch session
+        :return: response body
+        """
+        self.log.info("Fetching the logs for batch session with id: %s", session_id)
+        log_start_line = 0
+        log_total_lines = 0
+        log_batch_size = 100
+
+        while log_start_line <= log_total_lines:
+            # Livy log endpoint is paginated.
+            result = await self.get_batch_logs(session_id, log_start_line, log_batch_size)
+            if result["status"] == "success":
+                log_start_line += log_batch_size
+                log_lines = self._parse_request_response(result["response"], "log")
+                for log_line in log_lines:
+                    self.log.info(log_line)
+                return log_lines
+            else:
+                self.log.info(result["response"])
+                return result["response"]
+
+    @staticmethod
+    def _validate_session_id(session_id: int | str) -> None:
+        """
+        Validate that the session id is an int.
+
+        :param session_id: session id
+        """
+        try:
+            int(session_id)
+        except (TypeError, ValueError):
+            raise TypeError("'session_id' must be an integer")
+
+    @staticmethod
+    def _parse_post_response(response: dict[Any, Any]) -> Any:
+        """
+        Parse batch response for batch id.
+
+        :param response: response body
+        :return: session id
+        """
+        return response.get("id")
+
+    @staticmethod
+    def _parse_request_response(response: dict[Any, Any], parameter: Any) -> Any:
+        """
+        Parse the request response for the given parameter.
+
+        :param response: response body
+        :param parameter: key to look up in the response
+        :return: value of the parameter
+        """
+        return response.get(parameter)
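As a quick sanity check on the request-body builder defined just below, here is a sketch of the payload it produces; the file path and values are illustrative only:

```python
from airflow.providers.apache.livy.hooks.livy import LivyAsyncHook

# Illustrative values; the application path is a placeholder.
body = LivyAsyncHook.build_post_batch_body(
    file="hdfs:///jobs/pi.py",
    args=["10"],
    executor_memory="1g",
    num_executors=2,
    conf={"spark.dynamicAllocation.enabled": "false"},
)
# Keys are camelCased to match the Livy REST payload:
# {'file': 'hdfs:///jobs/pi.py', 'args': ['10'], 'executorMemory': '1g',
#  'numExecutors': 2, 'conf': {'spark.dynamicAllocation.enabled': 'false'}}
print(body)
```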
+
+    @staticmethod
+    def build_post_batch_body(
+        file: str,
+        args: Sequence[str | int | float] | None = None,
+        class_name: str | None = None,
+        jars: list[str] | None = None,
+        py_files: list[str] | None = None,
+        files: list[str] | None = None,
+        archives: list[str] | None = None,
+        name: str | None = None,
+        driver_memory: str | None = None,
+        driver_cores: int | str | None = None,
+        executor_memory: str | None = None,
+        executor_cores: int | None = None,
+        num_executors: int | str | None = None,
+        queue: str | None = None,
+        proxy_user: str | None = None,
+        conf: dict[Any, Any] | None = None,
+    ) -> dict[str, Any]:
+        """
+        Build the post batch request body.
+
+        :param file: Path of the file containing the application to execute (required).
+        :param proxy_user: User to impersonate when running the job.
+        :param class_name: Application Java/Spark main class.
+        :param args: Command line arguments for the application.
+        :param jars: Jars to be used in this session.
+        :param py_files: Python files to be used in this session.
+        :param files: Files to be used in this session.
+        :param driver_memory: Amount of memory to use for the driver process.
+        :param driver_cores: Number of cores to use for the driver process.
+        :param executor_memory: Amount of memory to use per executor process.
+        :param executor_cores: Number of cores to use for each executor.
+        :param num_executors: Number of executors to launch for this session.
+        :param archives: Archives to be used in this session.
+        :param queue: The name of the YARN queue to which the session is submitted.
+        :param name: The name of this session.
+        :param conf: Spark configuration properties.
+        :return: request body
+        """
+        body: dict[str, Any] = {"file": file}
+
+        if proxy_user:
+            body["proxyUser"] = proxy_user
+        if class_name:
+            body["className"] = class_name
+        if args and LivyAsyncHook._validate_list_of_stringables(args):
+            body["args"] = [str(val) for val in args]
+        if jars and LivyAsyncHook._validate_list_of_stringables(jars):
+            body["jars"] = jars
+        if py_files and LivyAsyncHook._validate_list_of_stringables(py_files):
+            body["pyFiles"] = py_files
+        if files and LivyAsyncHook._validate_list_of_stringables(files):
+            body["files"] = files
+        if driver_memory and LivyAsyncHook._validate_size_format(driver_memory):
+            body["driverMemory"] = driver_memory
+        if driver_cores:
+            body["driverCores"] = driver_cores
+        if executor_memory and LivyAsyncHook._validate_size_format(executor_memory):
+            body["executorMemory"] = executor_memory
+        if executor_cores:
+            body["executorCores"] = executor_cores
+        if num_executors:
+            body["numExecutors"] = num_executors
+        if archives and LivyAsyncHook._validate_list_of_stringables(archives):
+            body["archives"] = archives
+        if queue:
+            body["queue"] = queue
+        if name:
+            body["name"] = name
+        if conf and LivyAsyncHook._validate_extra_conf(conf):
+            body["conf"] = conf
+
+        return body
+
+    @staticmethod
+    def _validate_size_format(size: str) -> bool:
+        """
+        Validate size format.
+
+        :param size: size value
+        :return: true if valid format
+        """
+        if size and not (isinstance(size, str) and re.match(r"^\d+[kmgt]b?$", size, re.IGNORECASE)):
+            raise ValueError(f"Invalid java size format for string '{size}'")
+        return True
+
+    @staticmethod
+    def _validate_list_of_stringables(vals: Sequence[str | int | float]) -> bool:
+        """
+        Check that the values in the provided list can be converted to strings.
+
+        :param vals: list to validate
+        :return: true if valid
+        """
+        if (
+            vals is None
+            or not isinstance(vals, (tuple, list))
+            or any(1 for val in vals if not isinstance(val, (str, int, float)))
+        ):
+            raise ValueError("List of strings expected")
+        return True
+
+    @staticmethod
+    def _validate_extra_conf(conf: dict[Any, Any]) -> bool:
+        """
+        Check that configuration values are either strings or ints.
+
+        :param conf: configuration variable
+        :return: true if valid
+        """
+        if conf:
+            if not isinstance(conf, dict):
+                raise ValueError("'conf' argument must be a dict")
+            if any(True for k, v in conf.items() if not (v and isinstance(v, str) or isinstance(v, int))):
+                raise ValueError("'conf' values must be either strings or ints")
+        return True
diff --git a/airflow/providers/apache/livy/operators/livy.py b/airflow/providers/apache/livy/operators/livy.py
index 313f64c9f9afd..d842dec13ce10 100644
--- a/airflow/providers/apache/livy/operators/livy.py
+++ b/airflow/providers/apache/livy/operators/livy.py
@@ -23,6 +23,7 @@
 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
 from airflow.providers.apache.livy.hooks.livy import BatchState, LivyHook
+from airflow.providers.apache.livy.triggers.livy import LivyTrigger
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -56,6 +57,7 @@ class LivyOperator(BaseOperator):
         depends on the option that's being modified.
     :param extra_headers: A dictionary of headers passed to the HTTP request to livy.
     :param retry_args: Arguments which define the retry behaviour.
         See Tenacity documentation at https://github.com/jd/tenacity
+    :param deferrable: Run operator in the deferrable mode.
     """
@@ -87,6 +89,7 @@ def __init__(
         extra_options: dict[str, Any] | None = None,
         extra_headers: dict[str, Any] | None = None,
         retry_args: dict[str, Any] | None = None,
+        deferrable: bool = False,
         **kwargs: Any,
     ) -> None:
@@ -120,6 +123,7 @@ def __init__(
         self._livy_hook: LivyHook | None = None
         self._batch_id: int | str
         self.retry_args = retry_args
+        self.deferrable = deferrable
 
     def get_hook(self) -> LivyHook:
         """
@@ -138,13 +142,27 @@ def get_hook(self) -> LivyHook:
 
     def execute(self, context: Context) -> Any:
         self._batch_id = self.get_hook().post_batch(**self.spark_params)
-
-        if self._polling_interval > 0:
-            self.poll_for_termination(self._batch_id)
-
-        context["ti"].xcom_push(key="app_id", value=self.get_hook().get_batch(self._batch_id)["appId"])
-
-        return self._batch_id
+        self.log.info("Generated batch-id is %s", self._batch_id)
+
+        # Wait for the job to complete
+        if not self.deferrable:
+            if self._polling_interval > 0:
+                self.poll_for_termination(self._batch_id)
+            context["ti"].xcom_push(key="app_id", value=self.get_hook().get_batch(self._batch_id)["appId"])
+            return self._batch_id
+
+        self.defer(
+            timeout=self.execution_timeout,
+            trigger=LivyTrigger(
+                batch_id=self._batch_id,
+                spark_params=self.spark_params,
+                livy_conn_id=self._livy_conn_id,
+                polling_interval=self._polling_interval,
+                extra_options=self._extra_options,
+                extra_headers=self._extra_headers,
+            ),
+            method_name="execute_complete",
+        )
 
     def poll_for_termination(self, batch_id: int | str) -> None:
         """
@@ -170,3 +188,23 @@ def kill(self) -> None:
         """Delete the current batch session."""
         if self._batch_id is not None:
             self.get_hook().delete_batch(self._batch_id)
+
+    def execute_complete(self, context: Context, event: dict[str, Any]) -> Any:
+        """
+        Callback for when the trigger fires; returns immediately.
+        Relies on the trigger to throw an exception; otherwise it assumes execution was
+        successful.
+        """
+        # Dump the logs from Livy to the worker through the triggerer.
+        if event.get("log_lines", None) is not None:
+            for log_line in event["log_lines"]:
+                self.log.info(log_line)
+
+        if event["status"] == "error":
+            raise AirflowException(event["response"])
+        self.log.info(
+            "%s completed with response %s",
+            self.task_id,
+            event["response"],
+        )
+        return event["batch_id"]
diff --git a/airflow/providers/apache/livy/provider.yaml b/airflow/providers/apache/livy/provider.yaml
index 8e69ef935d03a..bd95c1f3cf0ca 100644
--- a/airflow/providers/apache/livy/provider.yaml
+++ b/airflow/providers/apache/livy/provider.yaml
@@ -38,6 +38,8 @@ versions:
 dependencies:
   - apache-airflow>=2.3.0
   - apache-airflow-providers-http
+  - aiohttp
+  - asgiref
 
 integrations:
   - integration-name: Apache Livy
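The operator and the trigger communicate through an implicit event contract. A hedged sketch of the two payload shapes ``execute_complete`` has to handle (field values are illustrative):

```python
from airflow.exceptions import AirflowException

# Every TriggerEvent payload carries the same four keys.
success_event = {"status": "success", "batch_id": 42, "response": "Batch 42 succeeded", "log_lines": None}
error_event = {"status": "error", "batch_id": 42, "response": "Batch 42 did not succeed", "log_lines": None}


def handle(event: dict) -> int:
    # Mirrors execute_complete: re-raise on error, return the batch id on success.
    if event["status"] == "error":
        raise AirflowException(event["response"])
    return event["batch_id"]


assert handle(success_event) == 42
```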
diff --git a/airflow/providers/apache/livy/triggers/__init__.py b/airflow/providers/apache/livy/triggers/__init__.py
new file mode 100644
index 0000000000000..13a83393a9124
--- /dev/null
+++ b/airflow/providers/apache/livy/triggers/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/apache/livy/triggers/livy.py b/airflow/providers/apache/livy/triggers/livy.py
new file mode 100644
index 0000000000000..cfcbde53b9f5f
--- /dev/null
+++ b/airflow/providers/apache/livy/triggers/livy.py
@@ -0,0 +1,146 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""This module contains the Apache Livy Trigger."""
+from __future__ import annotations
+
+import asyncio
+from typing import Any, AsyncIterator
+
+from airflow.providers.apache.livy.hooks.livy import BatchState, LivyAsyncHook
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class LivyTrigger(BaseTrigger):
+    """
+    Check for the state of a previously submitted job with batch_id.
+
+    :param batch_id: Batch job id
+    :param spark_params: Spark parameters; for example,
+        spark_params = {"file": "test/pi.py", "class_name": "org.apache.spark.examples.SparkPi",
+        "args": ["/usr/lib/spark/bin/run-example", "SparkPi", "10"], "jars": "command-runner.jar",
+        "driver_cores": 1, "executor_cores": 4, "num_executors": 1}
+    :param livy_conn_id: reference to a pre-defined Livy Connection.
+    :param polling_interval: time in seconds between polls for job completion. If polling_interval = 0,
+        the batch_id is returned immediately; if polling_interval > 0, the Livy job is polled for
+        termination at the given interval.
+    :param extra_options: A dictionary of options, where key is string and value
+        depends on the option that's being modified.
+    :param extra_headers: A dictionary of headers passed to the HTTP request to livy.
+    :param livy_hook_async: LivyAsyncHook object
+    """
+
+    def __init__(
+        self,
+        batch_id: int | str,
+        spark_params: dict[Any, Any],
+        livy_conn_id: str = "livy_default",
+        polling_interval: int = 0,
+        extra_options: dict[str, Any] | None = None,
+        extra_headers: dict[str, Any] | None = None,
+        livy_hook_async: LivyAsyncHook | None = None,
+    ):
+        super().__init__()
+        self._batch_id = batch_id
+        self.spark_params = spark_params
+        self._livy_conn_id = livy_conn_id
+        self._polling_interval = polling_interval
+        self._extra_options = extra_options
+        self._extra_headers = extra_headers
+        self._livy_hook_async = livy_hook_async
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Serializes LivyTrigger arguments and classpath."""
+        return (
+            "airflow.providers.apache.livy.triggers.livy.LivyTrigger",
+            {
+                "batch_id": self._batch_id,
+                "spark_params": self.spark_params,
+                "livy_conn_id": self._livy_conn_id,
+                "polling_interval": self._polling_interval,
+                "extra_options": self._extra_options,
+                "extra_headers": self._extra_headers,
+                "livy_hook_async": self._livy_hook_async,
+            },
+        )
+
+    async def run(self) -> AsyncIterator["TriggerEvent"]:
+        """
+        Run the trigger. If ``_polling_interval > 0``, poll Livy for batch termination
+        asynchronously; otherwise yield a success event immediately.
+        """
+        try:
+            if self._polling_interval > 0:
+                response = await self.poll_for_termination(self._batch_id)
+                yield TriggerEvent(response)
+            yield TriggerEvent(
+                {
+                    "status": "success",
+                    "batch_id": self._batch_id,
+                    "response": f"Batch {self._batch_id} succeeded",
+                    "log_lines": None,
+                }
+            )
+        except Exception as exc:
+            yield TriggerEvent(
+                {
+                    "status": "error",
+                    "batch_id": self._batch_id,
+                    "response": f"Batch {self._batch_id} did not succeed with {str(exc)}",
+                    "log_lines": None,
+                }
+            )
+
+    async def poll_for_termination(self, batch_id: int | str) -> dict[str, Any]:
+        """
+        Poll Livy for batch termination asynchronously.
+
+        :param batch_id: id of the batch session to monitor.
+        :return: dict with the termination status, batch id, response message, and log lines.
+ """ + hook = self._get_async_hook() + state = await hook.get_batch_state(batch_id) + self.log.info("Batch with id %s is in state: %s", batch_id, state["batch_state"].value) + while state["batch_state"] not in hook.TERMINAL_STATES: + self.log.info("Batch with id %s is in state: %s", batch_id, state["batch_state"].value) + self.log.info("Sleeping for %s seconds", self._polling_interval) + await asyncio.sleep(self._polling_interval) + state = await hook.get_batch_state(batch_id) + self.log.info("Batch with id %s terminated with state: %s", batch_id, state["batch_state"].value) + log_lines = await hook.dump_batch_logs(batch_id) + if state["batch_state"] != BatchState.SUCCESS: + return { + "status": "error", + "batch_id": batch_id, + "response": f"Batch {batch_id} did not succeed", + "log_lines": log_lines, + } + return { + "status": "success", + "batch_id": batch_id, + "response": f"Batch {batch_id} succeeded", + "log_lines": log_lines, + } + + def _get_async_hook(self) -> LivyAsyncHook: + if self._livy_hook_async is None or not isinstance(self._livy_hook_async, LivyAsyncHook): + self._livy_hook_async = LivyAsyncHook( + livy_conn_id=self._livy_conn_id, + extra_headers=self._extra_headers, + extra_options=self._extra_options, + ) + return self._livy_hook_async diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index 6a914f9ca514c..1af61088cf747 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -134,8 +134,10 @@ }, "apache.livy": { "deps": [ + "aiohttp", "apache-airflow-providers-http", - "apache-airflow>=2.3.0" + "apache-airflow>=2.3.0", + "asgiref" ], "cross-providers-deps": [ "http" diff --git a/tests/providers/apache/livy/compat.py b/tests/providers/apache/livy/compat.py new file mode 100644 index 0000000000000..af1f0d225713c --- /dev/null +++ b/tests/providers/apache/livy/compat.py @@ -0,0 +1,37 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +__all__ = ["async_mock", "AsyncMock"] + +import sys + +if sys.version_info < (3, 8): + # For compatibility with Python 3.7 + from asynctest import mock as async_mock + + # ``asynctest.mock.CoroutineMock`` which provide compatibility not working well with autospec=True + # as result "TypeError: object MagicMock can't be used in 'await' expression" could be raised. 
+ # Best solution in this case provide as spec actual awaitable object + # >>> from tests.providers.apache.livy.compat import AsyncMock + # >>> from foo.bar import SpamEgg + # >>> mock_something = AsyncMock(SpamEgg) + from asynctest.mock import CoroutineMock as AsyncMock +else: + from unittest import mock as async_mock + from unittest.mock import AsyncMock diff --git a/tests/providers/apache/livy/hooks/test_livy.py b/tests/providers/apache/livy/hooks/test_livy.py index 1b70a72d251da..63266dfe2329f 100644 --- a/tests/providers/apache/livy/hooks/test_livy.py +++ b/tests/providers/apache/livy/hooks/test_livy.py @@ -19,15 +19,19 @@ import json from unittest.mock import MagicMock, patch +import multidict import pytest +from aiohttp import ClientResponseError, RequestInfo from requests.exceptions import RequestException from airflow.exceptions import AirflowException from airflow.models import Connection -from airflow.providers.apache.livy.hooks.livy import BatchState, LivyHook +from airflow.providers.apache.livy.hooks.livy import BatchState, LivyAsyncHook, LivyHook from airflow.utils import db +from tests.providers.apache.livy.compat import AsyncMock, async_mock from tests.test_utils.db import clear_db_connections +LIVY_CONN_ID = LivyHook.default_conn_name DEFAULT_CONN_ID = LivyHook.default_conn_name DEFAULT_HOST = "livy" DEFAULT_SCHEMA = "http" @@ -400,3 +404,389 @@ def test_alternate_auth_type(self): hook.get_conn() auth_type.assert_called_once_with("login", "secret") + + +class TestLivyAsyncHook: + @pytest.mark.asyncio + @async_mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.run_method") + async def test_get_batch_state_running(self, mock_run_method): + """Asserts the batch state as running with success response.""" + mock_run_method.return_value = {"status": "success", "response": {"state": BatchState.RUNNING}} + hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID) + state = await hook.get_batch_state(BATCH_ID) + assert state == { + "batch_state": BatchState.RUNNING, + "response": "successfully fetched the batch state.", + "status": "success", + } + + @pytest.mark.asyncio + @async_mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.run_method") + async def test_get_batch_state_error(self, mock_run_method): + """Asserts the batch state as error with error response.""" + mock_run_method.return_value = {"status": "error", "response": {"state": "error"}} + hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID) + state = await hook.get_batch_state(BATCH_ID) + assert state["status"] == "error" + + @pytest.mark.asyncio + @async_mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.run_method") + async def test_get_batch_state_error_without_state(self, mock_run_method): + """Asserts the batch state as error without state returned as part of mock.""" + mock_run_method.return_value = {"status": "success", "response": {}} + hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID) + state = await hook.get_batch_state(BATCH_ID) + assert state["status"] == "error" + + @pytest.mark.asyncio + @async_mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.run_method") + async def test_get_batch_logs_success(self, mock_run_method): + """Asserts the batch log as success.""" + mock_run_method.return_value = {"status": "success", "response": {}} + hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID) + state = await hook.get_batch_logs(BATCH_ID, 0, 100) + assert state["status"] == "success" + + @pytest.mark.asyncio + 
@async_mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.run_method") + async def test_get_batch_logs_error(self, mock_run_method): + """Asserts the batch log for error.""" + mock_run_method.return_value = {"status": "error", "response": {}} + hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID) + state = await hook.get_batch_logs(BATCH_ID, 0, 100) + assert state["status"] == "error" + + @pytest.mark.asyncio + @async_mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.get_batch_logs") + async def test_dump_batch_logs_success(self, mock_get_batch_logs): + """Asserts the log dump log for success response.""" + mock_get_batch_logs.return_value = { + "status": "success", + "response": {"id": 1, "log": ["mock_log_1", "mock_log_2", "mock_log_3"]}, + } + hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID) + log_dump = await hook.dump_batch_logs(BATCH_ID) + assert log_dump == ["mock_log_1", "mock_log_2", "mock_log_3"] + + @pytest.mark.asyncio + @async_mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.get_batch_logs") + async def test_dump_batch_logs_error(self, mock_get_batch_logs): + """Asserts the log dump log for error response.""" + mock_get_batch_logs.return_value = { + "status": "error", + "response": {"id": 1, "log": ["mock_log_1", "mock_log_2"]}, + } + hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID) + log_dump = await hook.dump_batch_logs(BATCH_ID) + assert log_dump == {"id": 1, "log": ["mock_log_1", "mock_log_2"]} + + @pytest.mark.asyncio + @async_mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook._do_api_call_async") + async def test_run_method_success(self, mock_do_api_call_async): + """Asserts the run_method for success response.""" + mock_do_api_call_async.return_value = {"status": "error", "response": {"id": 1}} + hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID) + response = await hook.run_method("localhost", "GET") + assert response["status"] == "success" + + @pytest.mark.asyncio + @async_mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook._do_api_call_async") + async def test_run_method_error(self, mock_do_api_call_async): + """Asserts the run_method for error response.""" + mock_do_api_call_async.return_value = {"status": "error", "response": {"id": 1}} + hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID) + response = await hook.run_method("localhost", "abc") + assert response == {"status": "error", "response": "Invalid http method abc"} + + @pytest.mark.asyncio + @async_mock.patch("airflow.providers.apache.livy.hooks.livy.aiohttp.ClientSession") + @async_mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.get_connection") + async def test_do_api_call_async_post_method_with_success(self, mock_get_connection, mock_session): + """Asserts the _do_api_call_async for success response for POST method.""" + + async def mock_fun(arg1, arg2, arg3, arg4): + return {"status": "success"} + + mock_session.return_value.__aexit__.return_value = mock_fun + mock_session.return_value.__aenter__.return_value.post = AsyncMock() + mock_session.return_value.__aenter__.return_value.post.return_value.json = AsyncMock( + return_value={"status": "success"} + ) + GET_RUN_ENDPOINT = "api/jobs/runs/get" + hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID) + hook.http_conn_id = mock_get_connection + hook.http_conn_id.host = "https://localhost" + hook.http_conn_id.login = "login" + hook.http_conn_id.password = "PASSWORD" + response = await hook._do_api_call_async(GET_RUN_ENDPOINT) + assert response == {"status": "success"} + + 
@pytest.mark.asyncio + @async_mock.patch("airflow.providers.apache.livy.hooks.livy.aiohttp.ClientSession") + @async_mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.get_connection") + async def test_do_api_call_async_get_method_with_success(self, mock_get_connection, mock_session): + """Asserts the _do_api_call_async for GET method.""" + + async def mock_fun(arg1, arg2, arg3, arg4): + return {"status": "success"} + + mock_session.return_value.__aexit__.return_value = mock_fun + mock_session.return_value.__aenter__.return_value.get = AsyncMock() + mock_session.return_value.__aenter__.return_value.get.return_value.json = AsyncMock( + return_value={"status": "success"} + ) + GET_RUN_ENDPOINT = "api/jobs/runs/get" + hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID) + hook.method = "GET" + hook.http_conn_id = mock_get_connection + hook.http_conn_id.host = "test.com" + hook.http_conn_id.login = "login" + hook.http_conn_id.password = "PASSWORD" + hook.http_conn_id.extra_dejson = "" + response = await hook._do_api_call_async(GET_RUN_ENDPOINT) + assert response == {"status": "success"} + + @pytest.mark.asyncio + @async_mock.patch("airflow.providers.apache.livy.hooks.livy.aiohttp.ClientSession") + @async_mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.get_connection") + async def test_do_api_call_async_patch_method_with_success(self, mock_get_connection, mock_session): + """Asserts the _do_api_call_async for PATCH method.""" + + async def mock_fun(arg1, arg2, arg3, arg4): + return {"status": "success"} + + mock_session.return_value.__aexit__.return_value = mock_fun + mock_session.return_value.__aenter__.return_value.patch = AsyncMock() + mock_session.return_value.__aenter__.return_value.patch.return_value.json = AsyncMock( + return_value={"status": "success"} + ) + GET_RUN_ENDPOINT = "api/jobs/runs/get" + hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID) + hook.method = "PATCH" + hook.http_conn_id = mock_get_connection + hook.http_conn_id.host = "test.com" + hook.http_conn_id.login = "login" + hook.http_conn_id.password = "PASSWORD" + hook.http_conn_id.extra_dejson = "" + response = await hook._do_api_call_async(GET_RUN_ENDPOINT) + assert response == {"status": "success"} + + @pytest.mark.asyncio + @async_mock.patch("airflow.providers.apache.livy.hooks.livy.aiohttp.ClientSession") + @async_mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.get_connection") + async def test_do_api_call_async_unexpected_method_error(self, mock_get_connection, mock_session): + """Asserts the _do_api_call_async for unexpected method error""" + GET_RUN_ENDPOINT = "api/jobs/runs/get" + hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID) + hook.method = "abc" + hook.http_conn_id = mock_get_connection + hook.http_conn_id.host = "test.com" + hook.http_conn_id.login = "login" + hook.http_conn_id.password = "PASSWORD" + hook.http_conn_id.extra_dejson = "" + response = await hook._do_api_call_async(endpoint=GET_RUN_ENDPOINT, headers={}) + assert response == {"Response": "Unexpected HTTP Method: abc", "status": "error"} + + @pytest.mark.asyncio + @async_mock.patch("airflow.providers.apache.livy.hooks.livy.aiohttp.ClientSession") + @async_mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.get_connection") + async def test_do_api_call_async_with_type_error(self, mock_get_connection, mock_session): + """Asserts the _do_api_call_async for TypeError.""" + + async def mock_fun(arg1, arg2, arg3, arg4): + return {"random value"} + + 
mock_session.return_value.__aexit__.return_value = mock_fun + mock_session.return_value.__aenter__.return_value.patch.return_value.json.return_value = {} + hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID) + hook.method = "PATCH" + hook.retry_limit = 1 + hook.retry_delay = 1 + hook.http_conn_id = mock_get_connection + with pytest.raises(TypeError): + await hook._do_api_call_async(endpoint="", data="test", headers=mock_fun, extra_options=mock_fun) + + @pytest.mark.asyncio + @async_mock.patch("airflow.providers.apache.livy.hooks.livy.aiohttp.ClientSession") + @async_mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.get_connection") + async def test_do_api_call_async_with_client_response_error(self, mock_get_connection, mock_session): + """Asserts the _do_api_call_async for Client Response Error.""" + + async def mock_fun(arg1, arg2, arg3, arg4): + return {"random value"} + + mock_session.return_value.__aexit__.return_value = mock_fun + mock_session.return_value.__aenter__.return_value.patch = AsyncMock() + mock_session.return_value.__aenter__.return_value.patch.return_value.json.side_effect = ( + ClientResponseError( + request_info=RequestInfo(url="example.com", method="PATCH", headers=multidict.CIMultiDict()), + status=500, + history=[], + ) + ) + GET_RUN_ENDPOINT = "" + hook = LivyAsyncHook(livy_conn_id="livy_default") + hook.method = "PATCH" + hook.base_url = "" + hook.http_conn_id = mock_get_connection + hook.http_conn_id.host = "test.com" + hook.http_conn_id.login = "login" + hook.http_conn_id.password = "PASSWORD" + hook.http_conn_id.extra_dejson = "" + response = await hook._do_api_call_async(GET_RUN_ENDPOINT) + assert response["status"] == "error" + + def set_conn(self): + db.merge_conn( + Connection(conn_id=LIVY_CONN_ID, conn_type="http", host="host", schema="http", port=8998) + ) + db.merge_conn(Connection(conn_id="default_port", conn_type="http", host="http://host")) + db.merge_conn(Connection(conn_id="default_protocol", conn_type="http", host="host")) + db.merge_conn(Connection(conn_id="port_set", host="host", conn_type="http", port=1234)) + db.merge_conn(Connection(conn_id="schema_set", host="host", conn_type="http", schema="zzz")) + db.merge_conn( + Connection(conn_id="dont_override_schema", conn_type="http", host="http://host", schema="zzz") + ) + db.merge_conn(Connection(conn_id="missing_host", conn_type="http", port=1234)) + db.merge_conn(Connection(conn_id="invalid_uri", uri="http://invalid_uri:4321")) + + def test_build_get_hook(self): + self.set_conn() + connection_url_mapping = { + # id, expected + "default_port": "http://host", + "default_protocol": "http://host", + "port_set": "http://host:1234", + "schema_set": "zzz://host", + "dont_override_schema": "http://host", + } + + for conn_id, expected in connection_url_mapping.items(): + hook = LivyAsyncHook(livy_conn_id=conn_id) + response_conn: Connection = hook.get_connection(conn_id=conn_id) + assert isinstance(response_conn, Connection) + assert hook._generate_base_url(response_conn) == expected + + def test_build_body(self): + # minimal request + body = LivyAsyncHook.build_post_batch_body(file="appname") + + assert body == {"file": "appname"} + + # complex request + body = LivyAsyncHook.build_post_batch_body( + file="appname", + class_name="org.example.livy", + proxy_user="proxyUser", + args=["a", "1"], + jars=["jar1", "jar2"], + files=["file1", "file2"], + py_files=["py1", "py2"], + archives=["arch1", "arch2"], + queue="queue", + name="name", + conf={"a": "b"}, + driver_cores=2, + driver_memory="1M", + 
executor_memory="1m", + executor_cores="1", + num_executors="10", + ) + + assert body == { + "file": "appname", + "className": "org.example.livy", + "proxyUser": "proxyUser", + "args": ["a", "1"], + "jars": ["jar1", "jar2"], + "files": ["file1", "file2"], + "pyFiles": ["py1", "py2"], + "archives": ["arch1", "arch2"], + "queue": "queue", + "name": "name", + "conf": {"a": "b"}, + "driverCores": 2, + "driverMemory": "1M", + "executorMemory": "1m", + "executorCores": "1", + "numExecutors": "10", + } + + def test_parameters_validation(self): + with pytest.raises(ValueError): + LivyAsyncHook.build_post_batch_body(file="appname", executor_memory="xxx") + + assert LivyAsyncHook.build_post_batch_body(file="appname", args=["a", 1, 0.1])["args"] == [ + "a", + "1", + "0.1", + ] + + def test_parse_post_response(self): + res_id = LivyAsyncHook._parse_post_response({"id": BATCH_ID, "log": []}) + + assert BATCH_ID == res_id + + @pytest.mark.parametrize("valid_size", ["1m", "1mb", "1G", "1GB", "1Gb", None]) + def test_validate_size_format_success(self, valid_size): + assert LivyAsyncHook._validate_size_format(valid_size) + + @pytest.mark.parametrize("invalid_size", ["1Gb foo", "10", 1]) + def test_validate_size_format_failure(self, invalid_size): + with pytest.raises(ValueError): + assert LivyAsyncHook._validate_size_format(invalid_size) + + @pytest.mark.parametrize( + "valid_string", + [ + [1, "string"], + (1, "string"), + [], + ], + ) + def test_validate_list_of_stringables_success(self, valid_string): + assert LivyAsyncHook._validate_list_of_stringables(valid_string) + + @pytest.mark.parametrize("invalid_string", [{"a": "a"}, [1, {}], [1, None], None, 1, "string"]) + def test_validate_list_of_stringables_failure(self, invalid_string): + with pytest.raises(ValueError): + LivyAsyncHook._validate_list_of_stringables(invalid_string) + + @pytest.mark.parametrize( + "conf", + [ + {"k1": "v1", "k2": 0}, + {}, + None, + ], + ) + def test_validate_extra_conf_success(self, conf): + assert LivyAsyncHook._validate_extra_conf(conf) + + @pytest.mark.parametrize( + "conf", + [ + "k1=v1", + [("k1", "v1"), ("k2", 0)], + {"outer": {"inner": "val"}}, + {"has_val": "val", "no_val": None}, + {"has_val": "val", "no_val": ""}, + ], + ) + def test_validate_extra_conf_failure(self, conf): + with pytest.raises(ValueError): + LivyAsyncHook._validate_extra_conf(conf) + + def test_parse_request_response(self): + assert BATCH_ID == LivyAsyncHook._parse_request_response( + response={"id": BATCH_ID, "log": []}, parameter="id" + ) + + @pytest.mark.parametrize("conn_id", [100, 0]) + def test_check_session_id_success(self, conn_id): + assert LivyAsyncHook._validate_session_id(conn_id) is None + + @pytest.mark.parametrize("conn_id", [None, "asd"]) + def test_check_session_id_failure(self, conn_id): + with pytest.raises(TypeError): + LivyAsyncHook._validate_session_id(None) diff --git a/tests/providers/apache/livy/operators/test_livy.py b/tests/providers/apache/livy/operators/test_livy.py index eb305e1c2d7de..e902febf08469 100644 --- a/tests/providers/apache/livy/operators/test_livy.py +++ b/tests/providers/apache/livy/operators/test_livy.py @@ -21,7 +21,7 @@ import pytest -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, TaskDeferred from airflow.models import Connection from airflow.models.dag import DAG from airflow.providers.apache.livy.hooks.livy import BatchState, LivyHook @@ -180,3 +180,157 @@ def test_log_dump(self, mock_get_batch, mock_post, mock_get_logs, mock_get, capl 
mock_get.assert_called_once_with(BATCH_ID, retry_args=None) mock_get_logs.assert_called_once_with(BATCH_ID, 0, 100) + + @patch( + "airflow.providers.apache.livy.operators.livy.LivyHook.dump_batch_logs", + return_value=None, + ) + @patch("airflow.providers.apache.livy.operators.livy.LivyHook.get_batch_state") + def test_poll_for_termination_deferrable(self, mock_livy, mock_dump_logs): + state_list = 2 * [BatchState.RUNNING] + [BatchState.SUCCESS] + + def side_effect(_, retry_args): + if state_list: + return state_list.pop(0) + # fail if does not stop right before + raise AssertionError() + + mock_livy.side_effect = side_effect + + task = LivyOperator( + file="sparkapp", polling_interval=1, dag=self.dag, task_id="livy_example", deferrable=True + ) + task._livy_hook = task.get_hook() + task.poll_for_termination(BATCH_ID) + + mock_livy.assert_called_with(BATCH_ID, retry_args=None) + mock_dump_logs.assert_called_with(BATCH_ID) + assert mock_livy.call_count == 3 + + @patch( + "airflow.providers.apache.livy.operators.livy.LivyHook.dump_batch_logs", + return_value=None, + ) + @patch("airflow.providers.apache.livy.operators.livy.LivyHook.get_batch_state") + def test_poll_for_termination_fail_deferrable(self, mock_livy, mock_dump_logs): + state_list = 2 * [BatchState.RUNNING] + [BatchState.ERROR] + + def side_effect(_, retry_args): + if state_list: + return state_list.pop(0) + # fail if does not stop right before + raise AssertionError() + + mock_livy.side_effect = side_effect + + task = LivyOperator( + file="sparkapp", polling_interval=1, dag=self.dag, task_id="livy_example", deferrable=True + ) + task._livy_hook = task.get_hook() + + with pytest.raises(AirflowException): + task.poll_for_termination(BATCH_ID) + + mock_livy.assert_called_with(BATCH_ID, retry_args=None) + mock_dump_logs.assert_called_with(BATCH_ID) + assert mock_livy.call_count == 3 + + @patch( + "airflow.providers.apache.livy.operators.livy.LivyHook.dump_batch_logs", + return_value=None, + ) + @patch( + "airflow.providers.apache.livy.operators.livy.LivyHook.get_batch_state", + return_value=BatchState.SUCCESS, + ) + @patch("airflow.providers.apache.livy.operators.livy.LivyHook.post_batch", return_value=BATCH_ID) + @patch("airflow.providers.apache.livy.operators.livy.LivyHook.get_batch", return_value=GET_BATCH) + def test_execution_deferrable(self, mock_get_batch, mock_post, mock_get, mock_dump_logs): + task = LivyOperator( + livy_conn_id="livyunittest", + file="sparkapp", + polling_interval=1, + dag=self.dag, + task_id="livy_example", + deferrable=True, + ) + with pytest.raises(TaskDeferred): + task.execute(context=self.mock_context) + + call_args = {k: v for k, v in mock_post.call_args[1].items() if v} + assert call_args == {"file": "sparkapp"} + mock_get.assert_called_once_with(BATCH_ID, retry_args=None) + mock_dump_logs.assert_called_once_with(BATCH_ID) + mock_get_batch.assert_called_once_with(BATCH_ID) + self.mock_context["ti"].xcom_push.assert_called_once_with(key="app_id", value=APP_ID) + + @patch("airflow.providers.apache.livy.operators.livy.LivyHook.post_batch") + @patch("airflow.providers.apache.livy.operators.livy.LivyHook.get_batch", return_value=GET_BATCH) + def test_execution_with_extra_options_deferrable(self, mock_get_batch, mock_post): + extra_options = {"check_response": True} + task = LivyOperator( + file="sparkapp", + dag=self.dag, + task_id="livy_example", + extra_options=extra_options, + deferrable=True, + ) + + with pytest.raises(TaskDeferred): + task.execute(context=self.mock_context) + + assert 
task.get_hook().extra_options == extra_options + + @patch("airflow.providers.apache.livy.operators.livy.LivyHook.delete_batch") + @patch("airflow.providers.apache.livy.operators.livy.LivyHook.post_batch", return_value=BATCH_ID) + @patch("airflow.providers.apache.livy.operators.livy.LivyHook.get_batch", return_value=GET_BATCH) + def test_deletion_deferrable(self, mock_get_batch, mock_post, mock_delete): + task = LivyOperator( + livy_conn_id="livyunittest", + file="sparkapp", + dag=self.dag, + task_id="livy_example", + deferrable=True, + ) + with pytest.raises(TaskDeferred): + task.execute(context=self.mock_context) + task.kill() + + mock_delete.assert_called_once_with(BATCH_ID) + + def test_injected_hook_deferrable(self): + def_hook = LivyHook(livy_conn_id="livyunittest") + + task = LivyOperator(file="sparkapp", dag=self.dag, task_id="livy_example", deferrable=True) + task._livy_hook = def_hook + + assert task.get_hook() == def_hook + + @patch( + "airflow.providers.apache.livy.operators.livy.LivyHook.get_batch_state", + return_value=BatchState.SUCCESS, + ) + @patch("airflow.providers.apache.livy.operators.livy.LivyHook.get_batch_logs", return_value=LOG_RESPONSE) + @patch("airflow.providers.apache.livy.operators.livy.LivyHook.post_batch", return_value=BATCH_ID) + @patch("airflow.providers.apache.livy.operators.livy.LivyHook.get_batch", return_value=GET_BATCH) + def test_log_dump_deferrable(self, mock_get_batch, mock_post, mock_get_logs, mock_get, caplog): + task = LivyOperator( + livy_conn_id="livyunittest", + file="sparkapp", + dag=self.dag, + task_id="livy_example", + polling_interval=1, + deferrable=True, + ) + caplog.clear() + + with pytest.raises(TaskDeferred): + with caplog.at_level(level=logging.INFO, logger=task.get_hook().log.name): + task.execute(context=self.mock_context) + + assert "first_line" in caplog.messages + assert "second_line" in caplog.messages + assert "third_line" in caplog.messages + + mock_get.assert_called_once_with(BATCH_ID, retry_args=None) + mock_get_logs.assert_called_once_with(BATCH_ID, 0, 100) diff --git a/tests/providers/apache/livy/triggers/__init__.py b/tests/providers/apache/livy/triggers/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/tests/providers/apache/livy/triggers/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/providers/apache/livy/triggers/test_livy.py b/tests/providers/apache/livy/triggers/test_livy.py new file mode 100644 index 0000000000000..9e6dc02e75ed2 --- /dev/null +++ b/tests/providers/apache/livy/triggers/test_livy.py @@ -0,0 +1,202 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import asyncio + +import pytest +from aiohttp import ClientConnectionError + +from airflow.providers.apache.livy.hooks.livy import BatchState, LivyHook +from airflow.providers.apache.livy.triggers.livy import LivyTrigger +from airflow.triggers.base import TriggerEvent +from tests.providers.apache.livy.compat import async_mock + + +class TestLivyTrigger: + def test_livy_trigger_serialization(self): + """ + Asserts that the TaskStateTrigger correctly serializes its arguments + and classpath. + """ + trigger = LivyTrigger( + batch_id=1, spark_params={}, livy_conn_id=LivyHook.default_conn_name, polling_interval=0 + ) + classpath, kwargs = trigger.serialize() + assert classpath == "airflow.providers.apache.livy.triggers.livy.LivyTrigger" + assert kwargs == { + "batch_id": 1, + "spark_params": {}, + "livy_conn_id": LivyHook.default_conn_name, + "polling_interval": 0, + "extra_options": None, + "extra_headers": None, + "livy_hook_async": None, + } + + @pytest.mark.asyncio + @async_mock.patch("airflow.providers.apache.livy.triggers.livy.LivyTrigger.poll_for_termination") + async def test_livy_trigger_run_with_no_poll_interval(self, mock_poll_for_termination): + """ + Test if the task ran in the triggerer successfully with poll interval=0. + In the case when polling_interval=0, it should return the batch_id + """ + mock_poll_for_termination.return_value = {"status": "success"} + trigger = LivyTrigger( + batch_id=1, spark_params={}, livy_conn_id=LivyHook.default_conn_name, polling_interval=0 + ) + generator = trigger.run() + actual = await generator.asend(None) + assert ( + TriggerEvent( + {"status": "success", "batch_id": 1, "response": "Batch 1 succeeded", "log_lines": None} + ) + == actual + ) + + @pytest.mark.asyncio + @async_mock.patch("airflow.providers.apache.livy.triggers.livy.LivyTrigger.poll_for_termination") + async def test_livy_trigger_run_with_poll_interval_success(self, mock_poll_for_termination): + """ + Test if the task ran in the triggerer successfully with poll interval>0. In the case when + polling_interval > 0, it should return a success or failure status. 
+ """ + mock_poll_for_termination.return_value = {"status": "success"} + trigger = LivyTrigger( + batch_id=1, spark_params={}, livy_conn_id=LivyHook.default_conn_name, polling_interval=30 + ) + + generator = trigger.run() + actual = await generator.asend(None) + assert TriggerEvent({"status": "success"}) == actual + + @pytest.mark.asyncio + @async_mock.patch("airflow.providers.apache.livy.triggers.livy.LivyTrigger.poll_for_termination") + async def test_livy_trigger_run_with_poll_interval_error(self, mock_poll_for_termination): + """Test if the task in the trigger returned an error when poll_for_termination returned error.""" + mock_poll_for_termination.return_value = {"status": "error"} + trigger = LivyTrigger( + batch_id=1, spark_params={}, livy_conn_id=LivyHook.default_conn_name, polling_interval=30 + ) + + task = [i async for i in trigger.run()] + assert len(task) == 2 + assert TriggerEvent({"status": "error"}) in task + + @pytest.mark.asyncio + async def test_livy_trigger_run_with_exception(self): + """Test if the task in the trigger failed with a connection error when no connection is mocked.""" + trigger = LivyTrigger( + batch_id=1, spark_params={}, livy_conn_id=LivyHook.default_conn_name, polling_interval=30 + ) + + task = [i async for i in trigger.run()] + assert len(task) == 1 + assert ( + TriggerEvent( + { + "status": "error", + "batch_id": 1, + "response": "Batch 1 did not succeed with Cannot connect to host livy:8998 ssl:default " + "[Name or service not known]", + "log_lines": None, + } + ) + in task + ) + + @pytest.mark.asyncio + async def test_livy_trigger_poll_for_termination_with_client_error(self): + """ + Test if the poll_for_termination() in the trigger failed with a ClientConnectionError + when no connection is mocked. + """ + trigger = LivyTrigger( + batch_id=1, spark_params={}, livy_conn_id=LivyHook.default_conn_name, polling_interval=30 + ) + + with pytest.raises(ClientConnectionError): + await trigger.poll_for_termination(1) + + @pytest.mark.asyncio + @async_mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.get_batch_state") + @async_mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.dump_batch_logs") + async def test_livy_trigger_poll_for_termination_success( + self, mock_dump_batch_logs, mock_get_batch_state + ): + """ + Test if the poll_for_termination() in the triggerer returned success response when get_batch_state() + runs successfully. + """ + mock_get_batch_state.return_value = {"batch_state": BatchState.SUCCESS} + mock_dump_batch_logs.return_value = ["mock_log"] + trigger = LivyTrigger( + batch_id=1, spark_params={}, livy_conn_id=LivyHook.default_conn_name, polling_interval=30 + ) + + task = await trigger.poll_for_termination(1) + + assert task == { + "status": "success", + "batch_id": 1, + "response": "Batch 1 succeeded", + "log_lines": ["mock_log"], + } + + @pytest.mark.asyncio + @async_mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.get_batch_state") + @async_mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.dump_batch_logs") + async def test_livy_trigger_poll_for_termination_error(self, mock_dump_batch_logs, mock_get_batch_state): + """ + Test if the poll_for_termination() in the trigger returned error response when get_batch_state() + failed. 
+ """ + mock_get_batch_state.return_value = {"batch_state": BatchState.ERROR} + mock_dump_batch_logs.return_value = ["mock_log"] + trigger = LivyTrigger( + batch_id=1, spark_params={}, livy_conn_id=LivyHook.default_conn_name, polling_interval=30 + ) + + task = await trigger.poll_for_termination(1) + + assert task == { + "status": "error", + "batch_id": 1, + "response": "Batch 1 did not succeed", + "log_lines": ["mock_log"], + } + + @pytest.mark.asyncio + @async_mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.get_batch_state") + @async_mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.dump_batch_logs") + async def test_livy_trigger_poll_for_termination_state(self, mock_dump_batch_logs, mock_get_batch_state): + """ + Test if the poll_for_termination() in the trigger is still polling when get_batch_state() returned + NOT_STARTED. + """ + mock_get_batch_state.return_value = {"batch_state": BatchState.NOT_STARTED} + mock_dump_batch_logs.return_value = ["mock_log"] + trigger = LivyTrigger( + batch_id=1, spark_params={}, livy_conn_id=LivyHook.default_conn_name, polling_interval=30 + ) + + task = asyncio.create_task(trigger.poll_for_termination(1)) + await asyncio.sleep(0.5) + + # TriggerEvent was not returned + assert task.done() is False + asyncio.get_event_loop().stop()
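For completeness, a hedged sketch of how the new deferrable path is exercised from a DAG; the dag id, schedule, and application path are placeholders:

```python
from datetime import datetime

from airflow import DAG
from airflow.providers.apache.livy.operators.livy import LivyOperator

with DAG(
    dag_id="example_livy_deferrable",  # placeholder dag id
    start_date=datetime(2022, 1, 1),
    schedule_interval=None,
) as dag:
    # With deferrable=True the task submits the batch, then defers to LivyTrigger,
    # freeing the worker slot while the triggerer polls Livy every 30 seconds.
    submit_job = LivyOperator(
        task_id="submit_spark_job",
        file="hdfs:///jobs/pi.py",  # placeholder application path
        polling_interval=30,
        deferrable=True,
    )
```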