From 63d772aaab7ccd9bf1d7ef3244b16008205307e0 Mon Sep 17 00:00:00 2001 From: Qian Sun Date: Tue, 20 Jun 2023 14:55:34 +0800 Subject: [PATCH 1/8] Add Alibaba Cloud AnalyticDB Spark Support Apply suggestions from code review Co-authored-by: Utkarsh Sharma replace Exception with ValueError fix ci deps failure add version remove return in comment fix indent error address comments address comments --- .../alibaba/cloud/hooks/analyticdb_spark.py | 371 ++++++++++++++++++ .../cloud/operators/analyticdb_spark.py | 228 +++++++++++ .../alibaba/cloud/sensors/analyticdb_spark.py | 72 ++++ airflow/providers/alibaba/provider.yaml | 18 + airflow/utils/db.py | 13 + .../connections/alibaba.rst | 2 +- .../index.rst | 18 +- .../operators/analyticdb_spark.rst | 45 +++ docs/conf.py | 2 + docs/spelling_wordlist.txt | 2 + generated/provider_dependencies.json | 2 + .../cloud/hooks/test_analyticdb_spark.py | 193 +++++++++ .../cloud/operators/test_analyticdb_spark.py | 170 ++++++++ .../cloud/sensors/test_analyticdb_spark.py | 69 ++++ .../cloud/utils/analyticdb_spark_mock.py | 41 ++ .../alibaba/example_adb_spark_batch.py | 62 +++ .../alibaba/example_adb_spark_sql.py | 54 +++ 17 files changed, 1353 insertions(+), 9 deletions(-) create mode 100644 airflow/providers/alibaba/cloud/hooks/analyticdb_spark.py create mode 100644 airflow/providers/alibaba/cloud/operators/analyticdb_spark.py create mode 100644 airflow/providers/alibaba/cloud/sensors/analyticdb_spark.py create mode 100644 docs/apache-airflow-providers-alibaba/operators/analyticdb_spark.rst create mode 100644 tests/providers/alibaba/cloud/hooks/test_analyticdb_spark.py create mode 100644 tests/providers/alibaba/cloud/operators/test_analyticdb_spark.py create mode 100644 tests/providers/alibaba/cloud/sensors/test_analyticdb_spark.py create mode 100644 tests/providers/alibaba/cloud/utils/analyticdb_spark_mock.py create mode 100644 tests/system/providers/alibaba/example_adb_spark_batch.py create mode 100644 tests/system/providers/alibaba/example_adb_spark_sql.py diff --git a/airflow/providers/alibaba/cloud/hooks/analyticdb_spark.py b/airflow/providers/alibaba/cloud/hooks/analyticdb_spark.py new file mode 100644 index 0000000000000..39be927b28a3e --- /dev/null +++ b/airflow/providers/alibaba/cloud/hooks/analyticdb_spark.py @@ -0,0 +1,371 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import json +from enum import Enum +from typing import Any, Sequence + +from alibabacloud_adb20211201.client import Client +from alibabacloud_adb20211201.models import ( + GetSparkAppLogRequest, + GetSparkAppStateRequest, + GetSparkAppWebUiAddressRequest, + KillSparkAppRequest, + SubmitSparkAppRequest, + SubmitSparkAppResponse, +) +from alibabacloud_tea_openapi.models import Config + +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook +from airflow.utils.log.logging_mixin import LoggingMixin + + +class AppState(Enum): + """ + AnalyticDB Spark application states doc: + https://www.alibabacloud.com/help/en/analyticdb-for-mysql/latest/api-doc-adb-2021-12-01-api-struct + -sparkappinfo. + + """ + + SUBMITTED = "SUBMITTED" + STARTING = "STARTING" + RUNNING = "RUNNING" + FAILING = "FAILING" + FAILED = "FAILED" + KILLING = "KILLING" + KILLED = "KILLED" + SUCCEEDING = "SUCCEEDING" + COMPLETED = "COMPLETED" + FATAL = "FATAL" + UNKNOWN = "UNKNOWN" + + +class AnalyticDBSparkHook(BaseHook, LoggingMixin): + """ + Hook for AnalyticDB MySQL Spark through the REST API. + + :param adb_spark_conn_id: The Airflow connection used for AnalyticDB MySQL Spark credentials. + :param region: AnalyticDB MySQL region you want to submit spark application. + """ + + TERMINAL_STATES = {AppState.COMPLETED, AppState.FAILED, AppState.FATAL, AppState.KILLED} + + conn_name_attr = "alibabacloud_conn_id" + default_conn_name = "adb_spark_default" + conn_type = "adb_spark" + hook_name = "AnalyticDB Spark" + + def __init__( + self, adb_spark_conn_id: str = "adb_spark_default", region: str | None = None, *args, **kwargs + ) -> None: + self.adb_spark_conn_id = adb_spark_conn_id + self.adb_spark_conn = self.get_connection(adb_spark_conn_id) + self.region = self.get_default_region() if region is None else region + super().__init__(*args, **kwargs) + + def submit_spark_app( + self, cluster_id: str, rg_name: str, *args: Any, **kwargs: Any + ) -> SubmitSparkAppResponse: + """ + Perform request to submit spark application. + + :param cluster_id: The cluster ID of AnalyticDB MySQL 3.0 Data Lakehouse. + :param rg_name: The name of resource group in AnalyticDB MySQL 3.0 Data Lakehouse cluster. + """ + self.log.info("Submitting application") + request = SubmitSparkAppRequest( + dbcluster_id=cluster_id, + resource_group_name=rg_name, + data=json.dumps(self.build_submit_app_data(*args, **kwargs)), + app_type="BATCH", + ) + try: + return self.get_adb_spark_client().submit_spark_app(request) + except Exception as e: + self.log.error(e) + raise AirflowException("Errors when submit spark application") from e + + def submit_spark_sql( + self, cluster_id: str, rg_name: str, *args: Any, **kwargs: Any + ) -> SubmitSparkAppResponse: + """ + Perform request to submit spark sql. + + :param cluster_id: The cluster ID of AnalyticDB MySQL 3.0 Data Lakehouse. + :param rg_name: The name of resource group in AnalyticDB MySQL 3.0 Data Lakehouse cluster. + """ + self.log.info("Submitting Spark SQL") + request = SubmitSparkAppRequest( + dbcluster_id=cluster_id, + resource_group_name=rg_name, + data=self.build_submit_sql_data(*args, **kwargs), + app_type="SQL", + ) + try: + return self.get_adb_spark_client().submit_spark_app(request) + except Exception as e: + self.log.error(e) + raise AirflowException("Errors when submit spark sql") from e + + def get_spark_state(self, app_id: str) -> str: + """ + Fetch the state of the specified spark application. 
+ + :param app_id: identifier of the spark application + """ + self.log.debug("Fetching state for spark application %s", app_id) + try: + return ( + self.get_adb_spark_client() + .get_spark_app_state(GetSparkAppStateRequest(app_id=app_id)) + .body.data.state + ) + except Exception as e: + self.log.error(e) + raise AirflowException(f"Errors when fetching state for spark application: {app_id}") from e + + def get_spark_web_ui_address(self, app_id: str) -> str: + """ + Fetch the web ui address of the specified spark application. + + :param app_id: identifier of the spark application + """ + self.log.debug("Fetching web ui address for spark application %s", app_id) + try: + return ( + self.get_adb_spark_client() + .get_spark_app_web_ui_address(GetSparkAppWebUiAddressRequest(app_id=app_id)) + .body.data.web_ui_address + ) + except Exception as e: + self.log.error(e) + raise AirflowException( + f"Errors when fetching web ui address for spark application: {app_id}" + ) from e + + def get_spark_log(self, app_id: str) -> str: + """ + Get the logs for a specified spark application. + + :param app_id: identifier of the spark application + """ + self.log.debug("Fetching log for spark application %s", app_id) + try: + return ( + self.get_adb_spark_client() + .get_spark_app_log(GetSparkAppLogRequest(app_id=app_id)) + .body.data.log_content + ) + except Exception as e: + self.log.error(e) + raise AirflowException( + f"Errors when fetching log for spark application: {app_id}" + ) from e + + def kill_spark_app(self, app_id: str) -> None: + """ + Kill the specified spark application. + + :param app_id: identifier of the spark application + """ + self.log.info("Killing spark application %s", app_id) + try: + self.get_adb_spark_client().kill_spark_app(KillSparkAppRequest(app_id=app_id)) + except Exception as e: + self.log.error(e) + raise AirflowException(f"Errors when killing spark application: {app_id}") from e + + @staticmethod + def build_submit_app_data( + file: str | None = None, + class_name: str | None = None, + args: Sequence[str | int | float] | None = None, + conf: dict[Any, Any] | None = None, + jars: Sequence[str] | None = None, + py_files: Sequence[str] | None = None, + files: Sequence[str] | None = None, + driver_resource_spec: str | None = None, + executor_resource_spec: str | None = None, + num_executors: int | str | None = None, + archives: Sequence[str] | None = None, + name: str | None = None, + ) -> dict: + """ + Build the submit application request data + + :param file: path of the file containing the application to execute. + :param class_name: name of the application Java/Spark main class. + :param args: application command line arguments. + :param conf: Spark configuration properties. + :param jars: jars to be used in this application. + :param py_files: python files to be used in this application. + :param files: files to be used in this application. + :param driver_resource_spec: The resource specifications of the Spark driver. + :param executor_resource_spec: The resource specifications of each Spark executor. + :param num_executors: number of executors to launch for this application. + :param archives: archives to be used in this application. + :param name: name of this application. 
+ """ + if file is None: + raise ValueError("Parameter file is need when submit spark application.") + + data: dict[str, Any] = {"file": file} + extra_conf: dict[str, str] = {} + + if class_name: + data["className"] = class_name + if args and AnalyticDBSparkHook._validate_list_of_stringables(args): + data["args"] = [str(val) for val in args] + if driver_resource_spec: + extra_conf["spark.driver.resourceSpec"] = driver_resource_spec + if executor_resource_spec: + extra_conf["spark.executor.resourceSpec"] = executor_resource_spec + if num_executors: + extra_conf["spark.executor.instances"] = str(num_executors) + data["conf"] = extra_conf.copy() + if conf and AnalyticDBSparkHook._validate_extra_conf(conf): + data["conf"].update(conf) + if jars and AnalyticDBSparkHook._validate_list_of_stringables(jars): + data["jars"] = jars + if py_files and AnalyticDBSparkHook._validate_list_of_stringables(py_files): + data["pyFiles"] = py_files + if files and AnalyticDBSparkHook._validate_list_of_stringables(files): + data["files"] = files + if archives and AnalyticDBSparkHook._validate_list_of_stringables(archives): + data["archives"] = archives + if name: + data["name"] = name + + return data + + @staticmethod + def build_submit_sql_data( + sql: str | None = None, + conf: dict[Any, Any] | None = None, + driver_resource_spec: str | None = None, + executor_resource_spec: str | None = None, + num_executors: int | str | None = None, + name: str | None = None, + ) -> str: + """ + Build the submit spark sql request data. + + :param sql: The SQL query to execute. (templated) + :param conf: Spark configuration properties. + :param driver_resource_spec: The resource specifications of the Spark driver. + :param executor_resource_spec: The resource specifications of each Spark executor. + :param num_executors: number of executors to launch for this application. + :param name: name of this application. + """ + if sql is None: + raise ValueError("Parameter sql is need when submit spark sql.") + + extra_conf: dict[str, str] = {} + formatted_conf = "" + + if driver_resource_spec: + extra_conf["spark.driver.resourceSpec"] = driver_resource_spec + if executor_resource_spec: + extra_conf["spark.executor.resourceSpec"] = executor_resource_spec + if num_executors: + extra_conf["spark.executor.instances"] = str(num_executors) + if name: + extra_conf["spark.app.name"] = name + if conf and AnalyticDBSparkHook._validate_extra_conf(conf): + extra_conf.update(conf) + for key, value in extra_conf.items(): + formatted_conf += f"set {key} = {value};" + + return (formatted_conf + sql).strip() + + @staticmethod + def _validate_list_of_stringables(vals: Sequence[str | int | float]) -> bool: + """ + Check the values in the provided list can be converted to strings. + + :param vals: list to validate + :return: true if valid + """ + if ( + vals is None + or not isinstance(vals, (tuple, list)) + or any(1 for val in vals if not isinstance(val, (str, int, float))) + ): + raise ValueError("List of strings expected") + return True + + @staticmethod + def _validate_extra_conf(conf: dict[Any, Any]) -> bool: + """ + Check configuration values are either strings or ints. 
+ + :param conf: configuration variable + :return: true if valid + """ + if conf: + if not isinstance(conf, dict): + raise ValueError("'conf' argument must be a dict") + if any(True for k, v in conf.items() if not (v and isinstance(v, str) or isinstance(v, int))): + raise ValueError("'conf' values must be either strings or ints") + return True + + def get_adb_spark_client(self) -> Client: + """Get valid AnalyticDB MySQL Spark client.""" + assert self.region is not None + + extra_config = self.adb_spark_conn.extra_dejson + auth_type = extra_config.get("auth_type", None) + if not auth_type: + raise ValueError("No auth_type specified in extra_config.") + + if auth_type != "AK": + raise ValueError(f"Unsupported auth_type: {auth_type}") + adb_spark_access_key_id = extra_config.get("access_key_id", None) + adb_spark_access_secret = extra_config.get("access_key_secret", None) + if not adb_spark_access_key_id: + raise ValueError(f"No access_key_id is specified for connection: {self.adb_spark_conn_id}") + + if not adb_spark_access_secret: + raise ValueError(f"No access_key_secret is specified for connection: {self.adb_spark_conn_id}") + + return Client( + Config( + access_key_id=adb_spark_access_key_id, + access_key_secret=adb_spark_access_secret, + endpoint=f"adb.{self.region}.aliyuncs.com", + ) + ) + + def get_default_region(self) -> str | None: + """Get default region from connection.""" + + extra_config = self.adb_spark_conn.extra_dejson + auth_type = extra_config.get("auth_type", None) + if not auth_type: + raise ValueError("No auth_type specified in extra_config. ") + + if auth_type != "AK": + raise ValueError(f"Unsupported auth_type: {auth_type}") + + default_region = extra_config.get("region", None) + if not default_region: + raise ValueError(f"No region is specified for connection: {self.adb_spark_conn}") + return default_region diff --git a/airflow/providers/alibaba/cloud/operators/analyticdb_spark.py b/airflow/providers/alibaba/cloud/operators/analyticdb_spark.py new file mode 100644 index 0000000000000..8da771b67ee8e --- /dev/null +++ b/airflow/providers/alibaba/cloud/operators/analyticdb_spark.py @@ -0,0 +1,228 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from functools import cached_property +from time import sleep +from typing import TYPE_CHECKING, Any, Sequence + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator + +from airflow.providers.alibaba.cloud.hooks.analyticdb_spark import AnalyticDBSparkHook, AppState + +if TYPE_CHECKING: + from airflow.utils.context import Context + + +class AnalyticDBSparkBaseOperator(BaseOperator): + """Abstract base class that defines how users develop AnalyticDB Spark.""" + + def __init__( + self, + *, + adb_spark_conn_id: str = "adb_spark_default", + region: str | None = None, + polling_interval: int = 0, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + + self.app_id: str | None = None + self.polling_interval = polling_interval + + self._adb_spark_conn_id = adb_spark_conn_id + self._region = region + + self._adb_spark_hook: AnalyticDBSparkHook | None = None + + @cached_property + def get_hook(self) -> AnalyticDBSparkHook: + """ + Get valid hook. + + :return: hook + """ + if self._adb_spark_hook is None or not isinstance(self._adb_spark_hook, AnalyticDBSparkHook): + self._adb_spark_hook = AnalyticDBSparkHook( + adb_spark_conn_id=self._adb_spark_conn_id, region=self._region + ) + return self._adb_spark_hook + + def execute(self, context: Context) -> Any: + ... + + def monitor_application(self): + self.log.info("Monitoring application with %s", self.app_id) + + if self.polling_interval > 0: + self.poll_for_termination(self.app_id) + + def poll_for_termination(self, app_id: str) -> None: + """ + Poll for Spark application termination. + + :param app_id: id of the spark application to monitor + """ + hook = self.get_hook + state = hook.get_spark_state(app_id) + while AppState(state) not in AnalyticDBSparkHook.TERMINAL_STATES: + self.log.debug("Application with id %s is in state: %s", app_id, state) + sleep(self.polling_interval) + state = hook.get_spark_state(app_id) + self.log.info("Application with id %s terminated with state: %s", app_id, state) + self.log.info( + "Web ui address is %s for application with id %s", hook.get_spark_web_ui_address(app_id), app_id + ) + self.log.info(hook.get_spark_log(app_id)) + if AppState(state) != AppState.COMPLETED: + raise AirflowException(f"Application {app_id} did not succeed") + + def on_kill(self) -> None: + self.kill() + + def kill(self) -> None: + """Kill the specified application.""" + if self.app_id is not None: + self.get_hook.kill_spark_app(self.app_id) + + +class AnalyticDBSparkSQLOperator(AnalyticDBSparkBaseOperator): + """ + This operator wraps the AnalyticDB Spark REST API, allowing submission of a Spark SQL + application to the underlying cluster. + + :param sql: The SQL query to execute. + :param conf: Spark configuration properties. + :param driver_resource_spec: The resource specifications of the Spark driver. + :param executor_resource_spec: The resource specifications of each Spark executor. + :param num_executors: number of executors to launch for this application. + :param name: name of this application. + :param cluster_id: The cluster ID of AnalyticDB MySQL 3.0 Data Lakehouse. + :param rg_name: The name of resource group in AnalyticDB MySQL 3.0 Data Lakehouse cluster.
+ """ + + template_fields: Sequence[str] = ("spark_params",) + template_fields_renderers = {"spark_params": "json"} + + def __init__( + self, + *, + sql: str, + conf: dict[Any, Any] | None = None, + driver_resource_spec: str | None = None, + executor_resource_spec: str | None = None, + num_executors: int | str | None = None, + name: str | None = None, + cluster_id: str, + rg_name: str, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + + self.spark_params = { + "sql": sql, + "conf": conf, + "driver_resource_spec": driver_resource_spec, + "executor_resource_spec": executor_resource_spec, + "num_executors": num_executors, + "name": name, + } + + self._cluster_id = cluster_id + self._rg_name = rg_name + + def execute(self, context: Context) -> Any: + submit_response = self.get_hook.submit_spark_sql( + cluster_id=self._cluster_id, rg_name=self._rg_name, **self.spark_params + ) + self.app_id = submit_response.body.data.app_id + self.monitor_application() + return self.app_id + + +class AnalyticDBSparkBatchOperator(AnalyticDBSparkBaseOperator): + """ + This operator warps the AnalyticDB Spark REST API, allowing to submit a Spark batch + application to the underlying cluster. + + :param file: path of the file containing the application to execute. + :param class_name: name of the application Java/Spark main class. + :param args: application command line arguments. + :param conf: Spark configuration properties. + :param jars: jars to be used in this application. + :param py_files: python files to be used in this application. + :param files: files to be used in this application. + :param driver_resource_spec: The resource specifications of the Spark driver. + :param executor_resource_spec: The resource specifications of each Spark executor. + :param num_executors: number of executors to launch for this application. + :param archives: archives to be used in this application. + :param name: name of this application. + :param cluster_id: The cluster ID of AnalyticDB MySQL 3.0 Data Lakehouse. + :param rg_name: The name of resource group in AnalyticDB MySQL 3.0 Data Lakehouse cluster. 
+ """ + + template_fields: Sequence[str] = ("spark_params",) + template_fields_renderers = {"spark_params": "json"} + + def __init__( + self, + *, + file: str, + class_name: str | None = None, + args: Sequence[str | int | float] | None = None, + conf: dict[Any, Any] | None = None, + jars: Sequence[str] | None = None, + py_files: Sequence[str] | None = None, + files: Sequence[str] | None = None, + driver_resource_spec: str | None = None, + executor_resource_spec: str | None = None, + num_executors: int | str | None = None, + archives: Sequence[str] | None = None, + name: str | None = None, + cluster_id: str, + rg_name: str, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + + self.spark_params = { + "file": file, + "class_name": class_name, + "args": args, + "conf": conf, + "jars": jars, + "py_files": py_files, + "files": files, + "driver_resource_spec": driver_resource_spec, + "executor_resource_spec": executor_resource_spec, + "num_executors": num_executors, + "archives": archives, + "name": name, + } + + self._cluster_id = cluster_id + self._rg_name = rg_name + + def execute(self, context: Context) -> Any: + submit_response = self.get_hook.submit_spark_app( + cluster_id=self._cluster_id, rg_name=self._rg_name, **self.spark_params + ) + self.app_id = submit_response.body.data.app_id + self.monitor_application() + return self.app_id diff --git a/airflow/providers/alibaba/cloud/sensors/analyticdb_spark.py b/airflow/providers/alibaba/cloud/sensors/analyticdb_spark.py new file mode 100644 index 0000000000000..b0caa988858c0 --- /dev/null +++ b/airflow/providers/alibaba/cloud/sensors/analyticdb_spark.py @@ -0,0 +1,72 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from functools import cached_property +from typing import TYPE_CHECKING, Any, Sequence + +from airflow.providers.alibaba.cloud.hooks.analyticdb_spark import AnalyticDBSparkHook, AppState +from airflow.sensors.base import BaseSensorOperator + +if TYPE_CHECKING: + from airflow.utils.context import Context + + +class AnalyticDBSparkSensor(BaseSensorOperator): + """ + Monitor a AnalyticDB Spark session for termination. + + :param app_id: identifier of the monitored app depends on the option that's being modified. + :param adb_spark_conn_id: reference to a pre-defined ADB Spark connection. + :param region: AnalyticDB MySQL region you want to submit spark application. 
+ """ + + template_fields: Sequence[str] = ("app_id",) + + def __init__( + self, + *, + app_id: str, + adb_spark_conn_id: str = "adb_spark_default", + region: str | None = None, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.app_id = app_id + self._region = region + self._adb_spark_conn_id = adb_spark_conn_id + self._adb_spark_hook: AnalyticDBSparkHook | None = None + + @cached_property + def get_hook(self) -> AnalyticDBSparkHook: + """ + Get valid hook. + + :return: hook + """ + if self._adb_spark_hook is None or not isinstance(self._adb_spark_hook, AnalyticDBSparkHook): + self._adb_spark_hook = AnalyticDBSparkHook( + adb_spark_conn_id=self._adb_spark_conn_id, region=self._region + ) + return self._adb_spark_hook + + def poke(self, context: Context) -> bool: + app_id = self.app_id + + state = self.get_hook.get_spark_state(app_id) + return AppState(state) in AnalyticDBSparkHook.TERMINAL_STATES diff --git a/airflow/providers/alibaba/provider.yaml b/airflow/providers/alibaba/provider.yaml index ba926a9c1331d..fe0609ba8f07e 100644 --- a/airflow/providers/alibaba/provider.yaml +++ b/airflow/providers/alibaba/provider.yaml @@ -37,6 +37,8 @@ versions: dependencies: - apache-airflow>=2.4.0 - oss2>=2.14.0 + - alibabacloud_adb20211201>=1.0.0 + - alibabacloud_tea_openapi>=0.3.7 integrations: - integration-name: Alibaba Cloud OSS @@ -45,26 +47,42 @@ integrations: how-to-guide: - /docs/apache-airflow-providers-alibaba/operators/oss.rst tags: [alibaba] + - integration-name: Alibaba Cloud AnalyticDB Spark + external-doc-url: https://www.alibabacloud.com/help/en/analyticdb-for-mysql/latest/spark-offline-application-development + how-to-guide: + - /docs/apache-airflow-providers-alibaba/operators/analyticdb_spark.rst + tags: [alibaba] operators: - integration-name: Alibaba Cloud OSS python-modules: - airflow.providers.alibaba.cloud.operators.oss + - integration-name: Alibaba Cloud AnalyticDB Spark + python-modules: + - airflow.providers.alibaba.cloud.operators.analyticdb_spark sensors: - integration-name: Alibaba Cloud OSS python-modules: - airflow.providers.alibaba.cloud.sensors.oss_key + - integration-name: Alibaba Cloud AnalyticDB Spark + python-modules: + - airflow.providers.alibaba.cloud.sensors.analyticdb_spark hooks: - integration-name: Alibaba Cloud OSS python-modules: - airflow.providers.alibaba.cloud.hooks.oss + - integration-name: Alibaba Cloud AnalyticDB Spark + python-modules: + - airflow.providers.alibaba.cloud.hooks.analyticdb_spark connection-types: - hook-class-name: airflow.providers.alibaba.cloud.hooks.oss.OSSHook connection-type: oss + - hook-class-name: airflow.providers.alibaba.cloud.hooks.analyticdb_spark.AnalyticDBSparkHook + connection-type: adb_spark logging: - airflow.providers.alibaba.cloud.log.oss_task_handler.OSSTaskHandler diff --git a/airflow/utils/db.py b/airflow/utils/db.py index a76f0d4f675d1..4b4bfed0adf78 100644 --- a/airflow/utils/db.py +++ b/airflow/utils/db.py @@ -477,6 +477,19 @@ def create_default_connections(session: Session = NEW_SESSION): ), session, ) + merge_conn( + Connection( + conn_id="adb_spark_default", + conn_type="adb_spark", + extra="""{ + "auth_type": "AK", + "access_key_id": "", + "access_key_secret": "", + "region": ""} + """, + ), + session, + ) merge_conn( Connection( conn_id="pig_cli_default", diff --git a/docs/apache-airflow-providers-alibaba/connections/alibaba.rst b/docs/apache-airflow-providers-alibaba/connections/alibaba.rst index d697b4dc0c753..4cf4747d7e2a5 100644 --- 
a/docs/apache-airflow-providers-alibaba/connections/alibaba.rst +++ b/docs/apache-airflow-providers-alibaba/connections/alibaba.rst @@ -26,7 +26,7 @@ Authentication may be performed using `Security Token Service (STS) or a signed Default Connection IDs ---------------------- -The default connection ID is ``oss_default``. +The default connection IDs are ``oss_default`` and ``adb_spark_default``. Configuring the Connection -------------------------- diff --git a/docs/apache-airflow-providers-alibaba/index.rst b/docs/apache-airflow-providers-alibaba/index.rst index b82fda55546c6..8d15929cb5030 100644 --- a/docs/apache-airflow-providers-alibaba/index.rst +++ b/docs/apache-airflow-providers-alibaba/index.rst @@ -65,7 +65,7 @@ Package apache-airflow-providers-alibaba Alibaba Cloud integration (including `Alibaba Cloud `__). -Release: 2.4.0 +Release: 2.5.0 Provider package ---------------- @@ -83,13 +83,15 @@ for the minimum Airflow version supported) via Requirements ------------ -The minimum Apache Airflow version supported by this provider package is ``2.4.0``. +The minimum Apache Airflow version supported by this provider package is ``2.5.0``. -================== ================== -PIP package Version required -================== ================== -``apache-airflow`` ``>=2.4.0`` -``oss2`` ``>=2.14.0`` -================== ================== +============================ ================== +PIP package Version required +============================ ================== +``apache-airflow`` ``>=2.5.0`` +``oss2`` ``>=2.14.0`` +``alibabacloud_adb20211201`` ``>=1.0.0`` +``alibabacloud_tea_openapi`` ``>=0.3.7`` +============================ ================== .. include:: ../../airflow/providers/alibaba/CHANGELOG.rst diff --git a/docs/apache-airflow-providers-alibaba/operators/analyticdb_spark.rst b/docs/apache-airflow-providers-alibaba/operators/analyticdb_spark.rst new file mode 100644 index 0000000000000..ac3f0638ad24c --- /dev/null +++ b/docs/apache-airflow-providers-alibaba/operators/analyticdb_spark.rst @@ -0,0 +1,45 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +Alibaba Cloud AnalyticDB Spark Operators +======================================== + +Overview +-------- + +Airflow to Alibaba Cloud AnalyticDB Spark integration provides several operators to develop spark batch and sql applications. + + - :class:`~airflow.providers.alibaba.cloud.operators.analyticdb_spark.AnalyticDBSparkBatchOperator` + - :class:`~airflow.providers.alibaba.cloud.operators.analyticdb_spark.AnalyticDBSparkSQLOperator` + +Develop Spark batch applications +------------------------------------------- + +Purpose +""""""" + +This example dag uses ``AnalyticDBSparkBatchOperator`` to submit Spark Pi and Spark Logistic regression applications. 
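+
+A minimal task definition might look like the sketch below; the ``cluster_id``, ``rg_name`` and
+``region`` values are placeholders, and the complete runnable DAG is shown in the next section.
+
+.. code-block:: python
+
+    from airflow.providers.alibaba.cloud.operators.analyticdb_spark import AnalyticDBSparkBatchOperator
+
+    spark_pi = AnalyticDBSparkBatchOperator(
+        task_id="spark_pi",
+        file="oss://your-bucket/spark-examples.jar",
+        class_name="org.apache.spark.examples.SparkPi",
+        cluster_id="your-cluster-id",
+        rg_name="your-resource-group",
+        region="your-region",
+    )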
+ +Defining tasks +"""""""""""""" + +In the following code we submit Spark Pi and Spark Logistic regression applications. + +.. exampleinclude:: /../../tests/system/providers/alibaba/example_adb_spark_batch.py + :language: python + :start-after: [START howto_operator_adb_spark_batch] + :end-before: [END howto_operator_adb_spark_batch] diff --git a/docs/conf.py b/docs/conf.py index 27b3e07384d90..f45956c9c0c88 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -588,6 +588,8 @@ def _get_params(root_schema: dict, prefix: str = "", default_section: str = "") autodoc_mock_imports = [ "MySQLdb", "adal", + "alibabacloud_adb20211201", + "alibabacloud_tea_openapi", "analytics", "azure", "azure.cosmos", diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index babb506f4108f..6dee5c2ea2356 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -861,6 +861,7 @@ kwargs KYLIN Kylin kylin +Lakehouse LanguageServiceClient lastname latencies @@ -1357,6 +1358,7 @@ sourceRepository sourceUploadUrl Spark sparkApplication +sparkappinfo sparkcmd SparkPi SparkR diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index d77fd32e8d46d..a25786120f6c9 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -11,6 +11,8 @@ }, "alibaba": { "deps": [ + "alibabacloud_adb20211201>=1.0.0", + "alibabacloud_tea_openapi>=0.3.7", "apache-airflow>=2.4.0", "oss2>=2.14.0" ], diff --git a/tests/providers/alibaba/cloud/hooks/test_analyticdb_spark.py b/tests/providers/alibaba/cloud/hooks/test_analyticdb_spark.py new file mode 100644 index 0000000000000..dba4ee18b2a8c --- /dev/null +++ b/tests/providers/alibaba/cloud/hooks/test_analyticdb_spark.py @@ -0,0 +1,193 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from unittest import mock + +from alibabacloud_adb20211201.models import ( + GetSparkAppLogResponse, + GetSparkAppLogResponseBody, + GetSparkAppLogResponseBodyData, + GetSparkAppStateResponse, + GetSparkAppStateResponseBody, + GetSparkAppStateResponseBodyData, + GetSparkAppWebUiAddressResponse, + GetSparkAppWebUiAddressResponseBody, + GetSparkAppWebUiAddressResponseBodyData, + KillSparkAppResponse, + SubmitSparkAppResponse, +) + +from airflow.providers.alibaba.cloud.hooks.analyticdb_spark import AnalyticDBSparkHook +from tests.providers.alibaba.cloud.utils.analyticdb_spark_mock import mock_adb_spark_hook_default_project_id + +ADB_SPARK_STRING = "airflow.providers.alibaba.cloud.hooks.analyticdb_spark.{}" +MOCK_ADB_SPARK_CONN_ID = "mock_id" +MOCK_ADB_CLUSTER_ID = "mock_adb_cluster_id" +MOCK_ADB_RG_NAME = "mock_adb_rg_name" +MOCK_ADB_SPARK_ID = "mock_adb_spark_id" + + +class TestAnalyticDBSparkHook: + def setup_method(self): + with mock.patch( + ADB_SPARK_STRING.format("AnalyticDBSparkHook.__init__"), + new=mock_adb_spark_hook_default_project_id, + ): + self.hook = AnalyticDBSparkHook(adb_spark_conn_id=MOCK_ADB_SPARK_CONN_ID) + + def test_build_submit_app_data(self): + res_data = self.hook.build_submit_app_data( + file="oss://test_file", + class_name="com.aliyun.spark.SparkPi", + args=[1000, "test-args"], + conf={"spark.executor.instances": 1, "spark.eventLog.enabled": "true"}, + jars=["oss://1.jar", "oss://2.jar"], + py_files=["oss://1.py", "oss://2.py"], + files=["oss://1.file", "oss://2.file"], + driver_resource_spec="medium", + executor_resource_spec="medium", + num_executors=2, + archives=["oss://1.zip", "oss://2.zip"], + name="test", + ) + except_data = { + "file": "oss://test_file", + "className": "com.aliyun.spark.SparkPi", + "args": ["1000", "test-args"], + "conf": { + "spark.executor.instances": 1, + "spark.eventLog.enabled": "true", + "spark.driver.resourceSpec": "medium", + "spark.executor.resourceSpec": "medium", + }, + "jars": ["oss://1.jar", "oss://2.jar"], + "pyFiles": ["oss://1.py", "oss://2.py"], + "files": ["oss://1.file", "oss://2.file"], + "archives": ["oss://1.zip", "oss://2.zip"], + "name": "test", + } + assert res_data == except_data + + def test_build_submit_sql_data(self): + res_data = self.hook.build_submit_sql_data( + sql=""" + set spark.executor.instances=1; + show databases; + """, + conf={"spark.executor.instances": 2}, + driver_resource_spec="medium", + executor_resource_spec="medium", + num_executors=3, + name="test", + ) + except_data = "set spark.driver.resourceSpec = medium;set spark.executor.resourceSpec = medium;set " \ + "spark.executor.instances = 2;set spark.app.name = test;\n set " \ + "spark.executor.instances=1;\n show databases;" + assert res_data == except_data + + @mock.patch(ADB_SPARK_STRING.format("AnalyticDBSparkHook.get_adb_spark_client")) + def test_submit_spark_app(self, mock_service): + # Given + mock_client = mock_service.return_value + exists_method = mock_client.submit_spark_app + exists_method.return_value = SubmitSparkAppResponse(status_code=200) + + # When + res = self.hook.submit_spark_app(MOCK_ADB_CLUSTER_ID, MOCK_ADB_RG_NAME, "oss://test.py") + + # Then + assert isinstance(res, SubmitSparkAppResponse) + mock_service.assert_called_once_with() + + @mock.patch(ADB_SPARK_STRING.format("AnalyticDBSparkHook.get_adb_spark_client")) + def test_submit_spark_sql(self, mock_service): + # Given + mock_client = mock_service.return_value + exists_method = mock_client.submit_spark_app + 
exists_method.return_value = SubmitSparkAppResponse(status_code=200) + + # When + res = self.hook.submit_spark_sql(MOCK_ADB_CLUSTER_ID, MOCK_ADB_RG_NAME, "SELECT 1") + + # Then + assert isinstance(res, SubmitSparkAppResponse) + mock_service.assert_called_once_with() + + @mock.patch(ADB_SPARK_STRING.format("AnalyticDBSparkHook.get_adb_spark_client")) + def test_get_spark_state(self, mock_service): + # Given + mock_client = mock_service.return_value + exists_method = mock_client.get_spark_app_state + exists_method.return_value = GetSparkAppStateResponse( + body=GetSparkAppStateResponseBody(data=GetSparkAppStateResponseBodyData(state="RUNNING")) + ) + + # When + res = self.hook.get_spark_state(MOCK_ADB_SPARK_ID) + + # Then + assert res == "RUNNING" + mock_service.assert_called_once_with() + + @mock.patch(ADB_SPARK_STRING.format("AnalyticDBSparkHook.get_adb_spark_client")) + def test_get_spark_web_ui_address(self, mock_service): + # Given + mock_client = mock_service.return_value + exists_method = mock_client.get_spark_app_web_ui_address + exists_method.return_value = GetSparkAppWebUiAddressResponse( + body=GetSparkAppWebUiAddressResponseBody( + data=GetSparkAppWebUiAddressResponseBodyData(web_ui_address="https://mock-web-ui-address.com") + ) + ) + + # When + res = self.hook.get_spark_web_ui_address(MOCK_ADB_SPARK_ID) + + # Then + assert res == "https://mock-web-ui-address.com" + mock_service.assert_called_once_with() + + @mock.patch(ADB_SPARK_STRING.format("AnalyticDBSparkHook.get_adb_spark_client")) + def test_get_spark_log(self, mock_service): + # Given + mock_client = mock_service.return_value + exists_method = mock_client.get_spark_app_log + exists_method.return_value = GetSparkAppLogResponse( + body=GetSparkAppLogResponseBody(data=GetSparkAppLogResponseBodyData(log_content="Pi is 3.14")) + ) + + # When + res = self.hook.get_spark_log(MOCK_ADB_SPARK_ID) + + # Then + assert res == "Pi is 3.14" + mock_service.assert_called_once_with() + + @mock.patch(ADB_SPARK_STRING.format("AnalyticDBSparkHook.get_adb_spark_client")) + def test_kill_spark_app(self, mock_service): + # Given + mock_client = mock_service.return_value + exists_method = mock_client.kill_spark_app + exists_method.return_value = KillSparkAppResponse() + + # When + self.hook.kill_spark_app(MOCK_ADB_SPARK_ID) + + # Then + mock_service.assert_called_once_with() diff --git a/tests/providers/alibaba/cloud/operators/test_analyticdb_spark.py b/tests/providers/alibaba/cloud/operators/test_analyticdb_spark.py new file mode 100644 index 0000000000000..e5bfe65b1edff --- /dev/null +++ b/tests/providers/alibaba/cloud/operators/test_analyticdb_spark.py @@ -0,0 +1,170 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from unittest import mock + +import pytest + +from airflow import AirflowException +from airflow.providers.alibaba.cloud.operators.analyticdb_spark import ( + AnalyticDBSparkBatchOperator, + AnalyticDBSparkBaseOperator, + AnalyticDBSparkSQLOperator, +) + +ADB_SPARK_OPERATOR_STRING = "airflow.providers.alibaba.cloud.operators.analyticdb_spark.{}" + +MOCK_FILE = "oss://test.py" +MOCK_CLUSTER_ID = "mock_cluster_id" +MOCK_RG_NAME = "mock_rg_name" +MOCK_ADB_SPARK_CONN_ID = "mock_adb_spark_conn_id" +MOCK_REGION = "mock_region" +MOCK_TASK_ID = "mock_task_id" +MOCK_APP_ID = "mock_app_id" +MOCK_SQL = "SELECT 1;" + + +class TestAnalyticDBSparkBaseOperator: + + def setup_method(self): + self.operator = AnalyticDBSparkBaseOperator( + adb_spark_conn_id=MOCK_ADB_SPARK_CONN_ID, + region=MOCK_REGION, + task_id=MOCK_TASK_ID, + ) + + @mock.patch(ADB_SPARK_OPERATOR_STRING.format("AnalyticDBSparkHook")) + def test_get_hook(self, mock_hook): + self.operator.get_hook() + mock_hook.assert_called_once_with(adb_spark_conn_id=MOCK_ADB_SPARK_CONN_ID, region=MOCK_REGION) + + @mock.patch(ADB_SPARK_OPERATOR_STRING.format("AnalyticDBSparkBaseOperator.get_hook")) + def test_poll_for_termination(self, mock_hook): + # Given + mock_hook.return_value.get_spark_state.return_value = "COMPLETED" + + # When + self.operator.poll_for_termination(MOCK_APP_ID) + + @mock.patch(ADB_SPARK_OPERATOR_STRING.format("AnalyticDBSparkBaseOperator.get_hook")) + def test_poll_for_termination_with_exception(self, mock_hook): + # Given + mock_hook.return_value.get_spark_state.return_value = "FATAL" + + # When + with pytest.raises(AirflowException, match="Application mock_app_id did not succeed"): + self.operator.poll_for_termination(MOCK_APP_ID) + + +class TestAnalyticDBSparkBatchOperator: + @mock.patch(ADB_SPARK_OPERATOR_STRING.format("AnalyticDBSparkHook")) + def test_execute(self, mock_hook): + operator = AnalyticDBSparkBatchOperator( + file=MOCK_FILE, + cluster_id=MOCK_CLUSTER_ID, + rg_name=MOCK_RG_NAME, + adb_spark_conn_id=MOCK_ADB_SPARK_CONN_ID, + region=MOCK_REGION, + task_id=MOCK_TASK_ID, + ) + + operator.execute(None) + + mock_hook.assert_called_once_with(adb_spark_conn_id=MOCK_ADB_SPARK_CONN_ID, region=MOCK_REGION) + mock_hook.return_value.submit_spark_app.assert_called_once_with( + cluster_id=MOCK_CLUSTER_ID, + rg_name=MOCK_RG_NAME, + file=MOCK_FILE, + class_name=None, + args=None, + conf=None, + jars=None, + py_files=None, + files=None, + driver_resource_spec=None, + executor_resource_spec=None, + num_executors=None, + archives=None, + name=None, + ) + + @mock.patch(ADB_SPARK_OPERATOR_STRING.format("AnalyticDBSparkBaseOperator.get_hook")) + def test_execute_with_exception(self, mock_hook): + # Given + mock_hook.return_value.submit_spark_app.side_effect = ValueError("List of strings expected") + + # When + operator = AnalyticDBSparkBatchOperator( + file=MOCK_FILE, + args=(True, False), + cluster_id=MOCK_CLUSTER_ID, + rg_name=MOCK_RG_NAME, + adb_spark_conn_id=MOCK_ADB_SPARK_CONN_ID, + region=MOCK_REGION, + task_id=MOCK_TASK_ID, + ) + + with pytest.raises(ValueError, match="List of strings expected"): + operator.execute(None) + + +class TestAnalyticDBSparklSQLOperator: + @mock.patch(ADB_SPARK_OPERATOR_STRING.format("AnalyticDBSparkHook")) + def test_execute(self, mock_hook): + operator = AnalyticDBSparkSQLOperator( + sql=MOCK_SQL, + cluster_id=MOCK_CLUSTER_ID, + rg_name=MOCK_RG_NAME, + adb_spark_conn_id=MOCK_ADB_SPARK_CONN_ID, + region=MOCK_REGION, + task_id=MOCK_TASK_ID, + ) + + 
operator.execute(None) + + mock_hook.assert_called_once_with(adb_spark_conn_id=MOCK_ADB_SPARK_CONN_ID, region=MOCK_REGION) + mock_hook.return_value.submit_spark_sql.assert_called_once_with( + cluster_id=MOCK_CLUSTER_ID, + rg_name=MOCK_RG_NAME, + sql=MOCK_SQL, + conf=None, + driver_resource_spec=None, + executor_resource_spec=None, + num_executors=None, + name=None, + ) + + @mock.patch(ADB_SPARK_OPERATOR_STRING.format("AnalyticDBSparkBaseOperator.get_hook")) + def test_execute_with_exception(self, mock_hook): + # Given + mock_hook.return_value.submit_spark_sql.side_effect = ValueError("List of strings expected") + + # When + operator = AnalyticDBSparkSQLOperator( + sql=MOCK_SQL, + conf={"spark.eventLog.enabled": True}, + cluster_id=MOCK_CLUSTER_ID, + rg_name=MOCK_RG_NAME, + adb_spark_conn_id=MOCK_ADB_SPARK_CONN_ID, + region=MOCK_REGION, + task_id=MOCK_TASK_ID, + ) + + with pytest.raises(ValueError, match="List of strings expected"): + operator.execute(None) diff --git a/tests/providers/alibaba/cloud/sensors/test_analyticdb_spark.py b/tests/providers/alibaba/cloud/sensors/test_analyticdb_spark.py new file mode 100644 index 0000000000000..cea0b625d887d --- /dev/null +++ b/tests/providers/alibaba/cloud/sensors/test_analyticdb_spark.py @@ -0,0 +1,69 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from unittest import mock + +from airflow.providers.alibaba.cloud.sensors.analyticdb_spark import AnalyticDBSparkSensor +from airflow.utils import timezone + +ADB_SPARK_SENSOR_STRING = "airflow.providers.alibaba.cloud.sensors.analyticdb_spark.{}" +DEFAULT_DATE = timezone.datetime(2017, 1, 1) +MOCK_ADB_SPARK_CONN_ID = "mock_adb_spark_default" +MOCK_ADB_SPARK_ID = "mock_adb_spark_id" +MOCK_SENSOR_TASK_ID = "test-adb-spark-operator" +MOCK_REGION = "mock_region" + + +class TestAnalyticDBSparkSensor: + def setup_method(self): + self.sensor = AnalyticDBSparkSensor( + app_id=MOCK_ADB_SPARK_ID, + adb_spark_conn_id=MOCK_ADB_SPARK_CONN_ID, + region=MOCK_REGION, + task_id=MOCK_SENSOR_TASK_ID, + ) + + @mock.patch(ADB_SPARK_SENSOR_STRING.format("AnalyticDBSparkHook")) + def test_get_hook(self, mock_service): + self.sensor.get_hook() + mock_service.assert_called_once_with(adb_spark_conn_id=MOCK_ADB_SPARK_CONN_ID, region=MOCK_REGION) + + @mock.patch(ADB_SPARK_SENSOR_STRING.format("AnalyticDBSparkSensor.get_hook")) + def test_poke_terminal_state(self, mock_service): + # Given + mock_service.return_value.get_spark_state.return_value = "COMPLETED" + + # When + res = self.sensor.poke(None) + + # Then + assert res is True + mock_service.return_value.get_spark_state.assert_called_once_with(MOCK_ADB_SPARK_ID) + + @mock.patch(ADB_SPARK_SENSOR_STRING.format("AnalyticDBSparkSensor.get_hook")) + def test_poke_non_terminal_state(self, mock_service): + # Given + mock_service.return_value.get_spark_state.return_value = "RUNNING" + + # When + res = self.sensor.poke(None) + + # Then + assert res is False + mock_service.return_value.get_spark_state.assert_called_once_with(MOCK_ADB_SPARK_ID) diff --git a/tests/providers/alibaba/cloud/utils/analyticdb_spark_mock.py b/tests/providers/alibaba/cloud/utils/analyticdb_spark_mock.py new file mode 100644 index 0000000000000..b43cc1feb2c4c --- /dev/null +++ b/tests/providers/alibaba/cloud/utils/analyticdb_spark_mock.py @@ -0,0 +1,41 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import json + +from airflow.models import Connection + +ANALYTICDB_SPARK_PROJECT_ID_HOOK_UNIT_TEST = "example-project" + + +def mock_adb_spark_hook_default_project_id( + self, adb_spark_conn_id="mock_adb_spark_default", region="mock_region" +): + self.adb_spark_conn_id = adb_spark_conn_id + self.adb_spark_conn = Connection( + extra=json.dumps( + { + "auth_type": "AK", + "access_key_id": "mock_access_key_id", + "access_key_secret": "mock_access_key_secret", + "region": "mock_region", + } + ) + ) + self.region = region diff --git a/tests/system/providers/alibaba/example_adb_spark_batch.py b/tests/system/providers/alibaba/example_adb_spark_batch.py new file mode 100644 index 0000000000000..b6945190bc65f --- /dev/null +++ b/tests/system/providers/alibaba/example_adb_spark_batch.py @@ -0,0 +1,62 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +# Ignore missing args provided by default_args +# type: ignore[call-arg] +import os +from datetime import datetime + +from airflow.models.dag import DAG +from airflow.providers.alibaba.cloud.operators.analyticdb_spark import AnalyticDBSparkBatchOperator + +ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID") +DAG_ID = "adb_spark_batch_dag" +# [START howto_operator_adb_spark_batch] +with DAG( + dag_id=DAG_ID, + start_date=datetime(2021, 1, 1), + default_args={"cluster_id": "your cluster", "rg_name": "your resource group", "region": "your region"}, + max_active_runs=1, + catchup=False, +) as dag: + + spark_pi = AnalyticDBSparkBatchOperator( + task_id="task1", + file="local:///tmp/spark-examples.jar", + class_name="org.apache.spark.examples.SparkPi", + ) + + spark_lr = AnalyticDBSparkBatchOperator( + task_id="task2", + file="local:///tmp/spark-examples.jar", + class_name="org.apache.spark.examples.SparkLR", + ) + + spark_pi >> spark_lr + + from tests.system.utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "tearDown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() +# [END howto_operator_adb_spark_batch] + +from tests.system.utils import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag) diff --git a/tests/system/providers/alibaba/example_adb_spark_sql.py b/tests/system/providers/alibaba/example_adb_spark_sql.py new file mode 100644 index 0000000000000..851880fa7355b --- /dev/null +++ b/tests/system/providers/alibaba/example_adb_spark_sql.py @@ -0,0 +1,54 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +# Ignore missing args provided by default_args +# type: ignore[call-arg] +import os +from datetime import datetime + +from airflow.models.dag import DAG +from airflow.providers.alibaba.cloud.operators.analyticdb_spark import AnalyticDBSparkSQLOperator + +ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID") +DAG_ID = "adb_spark_sql_dag" +# [START howto_operator_adb_spark_sql] +with DAG( + dag_id=DAG_ID, + start_date=datetime(2021, 1, 1), + default_args={"cluster_id": "your cluster", "rg_name": "your resource group", "region": "your region"}, + max_active_runs=1, + catchup=False, +) as dag: + + show_databases = AnalyticDBSparkSQLOperator(task_id="task1", sql="SHOW DATABASES;") + + show_tables = AnalyticDBSparkSQLOperator(task_id="task2", sql="SHOW TABLES;") + + show_databases >> show_tables + + from tests.system.utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "tearDown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() +# [END howto_operator_adb_spark_sql] + +from tests.system.utils import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag) From 3b6e2d2b7484fda91e21840c2849693d256486a5 Mon Sep 17 00:00:00 2001 From: Qian Sun Date: Tue, 20 Jun 2023 17:01:29 +0800 Subject: [PATCH 2/8] revert db.py change --- airflow/utils/db.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/airflow/utils/db.py b/airflow/utils/db.py index 4b4bfed0adf78..a76f0d4f675d1 100644 --- a/airflow/utils/db.py +++ b/airflow/utils/db.py @@ -477,19 +477,6 @@ def create_default_connections(session: Session = NEW_SESSION): ), session, ) - merge_conn( - Connection( - conn_id="adb_spark_default", - conn_type="adb_spark", - extra="""{ - "auth_type": "AK", - "access_key_id": "", - "access_key_secret": "", - "region": ""} - """, - ), - session, - ) merge_conn( Connection( conn_id="pig_cli_default", From fffd8196fd3e7e1d791cebb9ffb6fd0432f3ec00 Mon Sep 17 00:00:00 2001 From: Qian Sun Date: Tue, 20 Jun 2023 18:18:01 +0800 Subject: [PATCH 3/8] address comments --- .../alibaba/cloud/hooks/analyticdb_spark.py | 9 ++------- .../alibaba/cloud/operators/analyticdb_spark.py | 1 - .../alibaba/cloud/sensors/analyticdb_spark.py | 2 -- docs/spelling_wordlist.txt | 2 +- .../alibaba/cloud/hooks/test_analyticdb_spark.py | 16 +++++++++++++--- .../cloud/operators/test_analyticdb_spark.py | 10 ++++++++-- .../cloud/sensors/test_analyticdb_spark.py | 3 +++ 7 files changed, 27 insertions(+), 16 deletions(-) diff --git a/airflow/providers/alibaba/cloud/hooks/analyticdb_spark.py b/airflow/providers/alibaba/cloud/hooks/analyticdb_spark.py index 39be927b28a3e..bf3eca172220b 100644 --- a/airflow/providers/alibaba/cloud/hooks/analyticdb_spark.py +++ b/airflow/providers/alibaba/cloud/hooks/analyticdb_spark.py @@ -176,9 +176,7 @@ def get_spark_log(self, app_id: str)
-> str: ) except Exception as e: self.log.error(e) - raise AirflowException( - f"Errors when fetching log for spark application: {app_id}" - ) from e + raise AirflowException(f"Errors when fetching log for spark application: {app_id}") from e def kill_spark_app(self, app_id: str) -> None: """ @@ -209,7 +207,7 @@ def build_submit_app_data( name: str | None = None, ) -> dict: """ - Build the submit application request data + Build the submit application request data. :param file: path of the file containing the application to execute. :param class_name: name of the application Java/Spark main class. @@ -302,7 +300,6 @@ def _validate_list_of_stringables(vals: Sequence[str | int | float]) -> bool: Check the values in the provided list can be converted to strings. :param vals: list to validate - :return: true if valid """ if ( vals is None @@ -318,7 +315,6 @@ def _validate_extra_conf(conf: dict[Any, Any]) -> bool: Check configuration values are either strings or ints. :param conf: configuration variable - :return: true if valid """ if conf: if not isinstance(conf, dict): @@ -356,7 +352,6 @@ def get_adb_spark_client(self) -> Client: def get_default_region(self) -> str | None: """Get default region from connection.""" - extra_config = self.adb_spark_conn.extra_dejson auth_type = extra_config.get("auth_type", None) if not auth_type: diff --git a/airflow/providers/alibaba/cloud/operators/analyticdb_spark.py b/airflow/providers/alibaba/cloud/operators/analyticdb_spark.py index 8da771b67ee8e..cc0260ee681a8 100644 --- a/airflow/providers/alibaba/cloud/operators/analyticdb_spark.py +++ b/airflow/providers/alibaba/cloud/operators/analyticdb_spark.py @@ -23,7 +23,6 @@ from airflow.exceptions import AirflowException from airflow.models import BaseOperator - from airflow.providers.alibaba.cloud.hooks.analyticdb_spark import AnalyticDBSparkHook, AppState if TYPE_CHECKING: diff --git a/airflow/providers/alibaba/cloud/sensors/analyticdb_spark.py b/airflow/providers/alibaba/cloud/sensors/analyticdb_spark.py index b0caa988858c0..24640a5b0cc7e 100644 --- a/airflow/providers/alibaba/cloud/sensors/analyticdb_spark.py +++ b/airflow/providers/alibaba/cloud/sensors/analyticdb_spark.py @@ -56,8 +56,6 @@ def __init__( def get_hook(self) -> AnalyticDBSparkHook: """ Get valid hook. 
- - :return: hook """ if self._adb_spark_hook is None or not isinstance(self._adb_spark_hook, AnalyticDBSparkHook): self._adb_spark_hook = AnalyticDBSparkHook( diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 6dee5c2ea2356..9a4ae4eaaa34c 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -1357,8 +1357,8 @@ sourceArchiveUrl sourceRepository sourceUploadUrl Spark -sparkApplication sparkappinfo +sparkApplication sparkcmd SparkPi SparkR diff --git a/tests/providers/alibaba/cloud/hooks/test_analyticdb_spark.py b/tests/providers/alibaba/cloud/hooks/test_analyticdb_spark.py index dba4ee18b2a8c..bf38a3f7ca666 100644 --- a/tests/providers/alibaba/cloud/hooks/test_analyticdb_spark.py +++ b/tests/providers/alibaba/cloud/hooks/test_analyticdb_spark.py @@ -52,6 +52,7 @@ def setup_method(self): self.hook = AnalyticDBSparkHook(adb_spark_conn_id=MOCK_ADB_SPARK_CONN_ID) def test_build_submit_app_data(self): + """Test build submit application data for analyticDB spark as expected.""" res_data = self.hook.build_submit_app_data( file="oss://test_file", class_name="com.aliyun.spark.SparkPi", @@ -85,6 +86,7 @@ def test_build_submit_app_data(self): assert res_data == except_data def test_build_submit_sql_data(self): + """Test build submit sql data for analyticDB spark as expected.""" res_data = self.hook.build_submit_sql_data( sql=""" set spark.executor.instances=1; @@ -96,13 +98,16 @@ def test_build_submit_sql_data(self): num_executors=3, name="test", ) - except_data = "set spark.driver.resourceSpec = medium;set spark.executor.resourceSpec = medium;set " \ - "spark.executor.instances = 2;set spark.app.name = test;\n set " \ - "spark.executor.instances=1;\n show databases;" + except_data = ( + "set spark.driver.resourceSpec = medium;set spark.executor.resourceSpec = medium;set " + "spark.executor.instances = 2;set spark.app.name = test;\n set " + "spark.executor.instances=1;\n show databases;" + ) assert res_data == except_data @mock.patch(ADB_SPARK_STRING.format("AnalyticDBSparkHook.get_adb_spark_client")) def test_submit_spark_app(self, mock_service): + """Test submit_spark_app function works as expected.""" # Given mock_client = mock_service.return_value exists_method = mock_client.submit_spark_app @@ -117,6 +122,7 @@ def test_submit_spark_app(self, mock_service): @mock.patch(ADB_SPARK_STRING.format("AnalyticDBSparkHook.get_adb_spark_client")) def test_submit_spark_sql(self, mock_service): + """Test submit_spark_app function works as expected.""" # Given mock_client = mock_service.return_value exists_method = mock_client.submit_spark_app @@ -131,6 +137,7 @@ def test_submit_spark_sql(self, mock_service): @mock.patch(ADB_SPARK_STRING.format("AnalyticDBSparkHook.get_adb_spark_client")) def test_get_spark_state(self, mock_service): + """Test get_spark_state function works as expected.""" # Given mock_client = mock_service.return_value exists_method = mock_client.get_spark_app_state @@ -147,6 +154,7 @@ def test_get_spark_state(self, mock_service): @mock.patch(ADB_SPARK_STRING.format("AnalyticDBSparkHook.get_adb_spark_client")) def test_get_spark_web_ui_address(self, mock_service): + """Test get_spark_web_ui_address function works as expected.""" # Given mock_client = mock_service.return_value exists_method = mock_client.get_spark_app_web_ui_address @@ -165,6 +173,7 @@ def test_get_spark_web_ui_address(self, mock_service): @mock.patch(ADB_SPARK_STRING.format("AnalyticDBSparkHook.get_adb_spark_client")) def test_get_spark_log(self, mock_service): + """Test 
get_spark_log function works as expected.""" # Given mock_client = mock_service.return_value exists_method = mock_client.get_spark_app_log @@ -181,6 +190,7 @@ def test_get_spark_log(self, mock_service): @mock.patch(ADB_SPARK_STRING.format("AnalyticDBSparkHook.get_adb_spark_client")) def test_kill_spark_app(self, mock_service): + """Test kill_spark_app function works as expected.""" # Given mock_client = mock_service.return_value exists_method = mock_client.kill_spark_app diff --git a/tests/providers/alibaba/cloud/operators/test_analyticdb_spark.py b/tests/providers/alibaba/cloud/operators/test_analyticdb_spark.py index e5bfe65b1edff..e4eee6d46b146 100644 --- a/tests/providers/alibaba/cloud/operators/test_analyticdb_spark.py +++ b/tests/providers/alibaba/cloud/operators/test_analyticdb_spark.py @@ -23,8 +23,8 @@ from airflow import AirflowException from airflow.providers.alibaba.cloud.operators.analyticdb_spark import ( - AnalyticDBSparkBatchOperator, AnalyticDBSparkBaseOperator, + AnalyticDBSparkBatchOperator, AnalyticDBSparkSQLOperator, ) @@ -41,7 +41,6 @@ class TestAnalyticDBSparkBaseOperator: - def setup_method(self): self.operator = AnalyticDBSparkBaseOperator( adb_spark_conn_id=MOCK_ADB_SPARK_CONN_ID, @@ -51,11 +50,13 @@ def setup_method(self): @mock.patch(ADB_SPARK_OPERATOR_STRING.format("AnalyticDBSparkHook")) def test_get_hook(self, mock_hook): + """Test get_hook function works as expected.""" self.operator.get_hook() mock_hook.assert_called_once_with(adb_spark_conn_id=MOCK_ADB_SPARK_CONN_ID, region=MOCK_REGION) @mock.patch(ADB_SPARK_OPERATOR_STRING.format("AnalyticDBSparkBaseOperator.get_hook")) def test_poll_for_termination(self, mock_hook): + """Test poll_for_termination works as expected with COMPLETED application.""" # Given mock_hook.return_value.get_spark_state.return_value = "COMPLETED" @@ -64,6 +65,7 @@ def test_poll_for_termination(self, mock_hook): @mock.patch(ADB_SPARK_OPERATOR_STRING.format("AnalyticDBSparkBaseOperator.get_hook")) def test_poll_for_termination_with_exception(self, mock_hook): + """Test poll_for_termination raises AirflowException with FATAL application.""" # Given mock_hook.return_value.get_spark_state.return_value = "FATAL" @@ -75,6 +77,7 @@ def test_poll_for_termination_with_exception(self, mock_hook): class TestAnalyticDBSparkBatchOperator: @mock.patch(ADB_SPARK_OPERATOR_STRING.format("AnalyticDBSparkHook")) def test_execute(self, mock_hook): + """Test submit AnalyticDB Spark Batch Application works as expected.""" operator = AnalyticDBSparkBatchOperator( file=MOCK_FILE, cluster_id=MOCK_CLUSTER_ID, @@ -106,6 +109,7 @@ def test_execute(self, mock_hook): @mock.patch(ADB_SPARK_OPERATOR_STRING.format("AnalyticDBSparkBaseOperator.get_hook")) def test_execute_with_exception(self, mock_hook): + """Test submit AnalyticDB Spark Batch Application raises ValueError with invalid parameter.""" # Given mock_hook.return_value.submit_spark_app.side_effect = ValueError("List of strings expected") @@ -127,6 +131,7 @@ def test_execute_with_exception(self, mock_hook): class TestAnalyticDBSparklSQLOperator: @mock.patch(ADB_SPARK_OPERATOR_STRING.format("AnalyticDBSparkHook")) def test_execute(self, mock_hook): + """Test submit AnalyticDB Spark SQL Application works as expected.""" operator = AnalyticDBSparkSQLOperator( sql=MOCK_SQL, cluster_id=MOCK_CLUSTER_ID, @@ -152,6 +157,7 @@ def test_execute(self, mock_hook): @mock.patch(ADB_SPARK_OPERATOR_STRING.format("AnalyticDBSparkBaseOperator.get_hook")) def test_execute_with_exception(self, mock_hook): + """Test submit 
AnalyticDB Spark SQL Application raises ValueError with invalid parameter.""" # Given mock_hook.return_value.submit_spark_sql.side_effect = ValueError("List of strings expected") diff --git a/tests/providers/alibaba/cloud/sensors/test_analyticdb_spark.py b/tests/providers/alibaba/cloud/sensors/test_analyticdb_spark.py index cea0b625d887d..68df8bdbeacc0 100644 --- a/tests/providers/alibaba/cloud/sensors/test_analyticdb_spark.py +++ b/tests/providers/alibaba/cloud/sensors/test_analyticdb_spark.py @@ -41,11 +41,13 @@ def setup_method(self): @mock.patch(ADB_SPARK_SENSOR_STRING.format("AnalyticDBSparkHook")) def test_get_hook(self, mock_service): + """Test get_hook function works as expected.""" self.sensor.get_hook() mock_service.assert_called_once_with(adb_spark_conn_id=MOCK_ADB_SPARK_CONN_ID, region=MOCK_REGION) @mock.patch(ADB_SPARK_SENSOR_STRING.format("AnalyticDBSparkSensor.get_hook")) def test_poke_terminal_state(self, mock_service): + """Test poke_terminal_state works as expected with COMPLETED application.""" # Given mock_service.return_value.get_spark_state.return_value = "COMPLETED" @@ -58,6 +60,7 @@ def test_poke_terminal_state(self, mock_service): @mock.patch(ADB_SPARK_SENSOR_STRING.format("AnalyticDBSparkSensor.get_hook")) def test_poke_non_terminal_state(self, mock_service): + """Test poke_terminal_state works as expected with RUNNING application.""" # Given mock_service.return_value.get_spark_state.return_value = "RUNNING" From 2896c0040d395dc9347eac7e2f65bd95f4e15ca5 Mon Sep 17 00:00:00 2001 From: Qian Sun Date: Tue, 20 Jun 2023 18:19:20 +0800 Subject: [PATCH 4/8] address comments --- airflow/providers/alibaba/cloud/operators/analyticdb_spark.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/airflow/providers/alibaba/cloud/operators/analyticdb_spark.py b/airflow/providers/alibaba/cloud/operators/analyticdb_spark.py index cc0260ee681a8..6ae88bb3d375c 100644 --- a/airflow/providers/alibaba/cloud/operators/analyticdb_spark.py +++ b/airflow/providers/alibaba/cloud/operators/analyticdb_spark.py @@ -54,8 +54,6 @@ def __init__( def get_hook(self) -> AnalyticDBSparkHook: """ Get valid hook. 
- - :return: hook """ if self._adb_spark_hook is None or not isinstance(self._adb_spark_hook, AnalyticDBSparkHook): self._adb_spark_hook = AnalyticDBSparkHook( From c9cd7b9ff69f6fb0b125a88cd377f3940ed3d3a5 Mon Sep 17 00:00:00 2001 From: Qian Sun Date: Wed, 21 Jun 2023 09:47:49 +0800 Subject: [PATCH 5/8] fix ut error --- .../alibaba/cloud/operators/test_analyticdb_spark.py | 8 ++++---- .../alibaba/cloud/sensors/test_analyticdb_spark.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/providers/alibaba/cloud/operators/test_analyticdb_spark.py b/tests/providers/alibaba/cloud/operators/test_analyticdb_spark.py index e4eee6d46b146..eb2db3ff39567 100644 --- a/tests/providers/alibaba/cloud/operators/test_analyticdb_spark.py +++ b/tests/providers/alibaba/cloud/operators/test_analyticdb_spark.py @@ -58,7 +58,7 @@ def test_get_hook(self, mock_hook): def test_poll_for_termination(self, mock_hook): """Test poll_for_termination works as expected with COMPLETED application.""" # Given - mock_hook.return_value.get_spark_state.return_value = "COMPLETED" + mock_hook.get_spark_state.return_value = "COMPLETED" # When self.operator.poll_for_termination(MOCK_APP_ID) @@ -67,7 +67,7 @@ def test_poll_for_termination(self, mock_hook): def test_poll_for_termination_with_exception(self, mock_hook): """Test poll_for_termination raises AirflowException with FATAL application.""" # Given - mock_hook.return_value.get_spark_state.return_value = "FATAL" + mock_hook.get_spark_state.return_value = "FATAL" # When with pytest.raises(AirflowException, match="Application mock_app_id did not succeed"): @@ -111,7 +111,7 @@ def test_execute(self, mock_hook): def test_execute_with_exception(self, mock_hook): """Test submit AnalyticDB Spark Batch Application raises ValueError with invalid parameter.""" # Given - mock_hook.return_value.submit_spark_app.side_effect = ValueError("List of strings expected") + mock_hook.submit_spark_app.side_effect = ValueError("List of strings expected") # When operator = AnalyticDBSparkBatchOperator( @@ -159,7 +159,7 @@ def test_execute(self, mock_hook): def test_execute_with_exception(self, mock_hook): """Test submit AnalyticDB Spark SQL Application raises ValueError with invalid parameter.""" # Given - mock_hook.return_value.submit_spark_sql.side_effect = ValueError("List of strings expected") + mock_hook.submit_spark_sql.side_effect = ValueError("List of strings expected") # When operator = AnalyticDBSparkSQLOperator( diff --git a/tests/providers/alibaba/cloud/sensors/test_analyticdb_spark.py b/tests/providers/alibaba/cloud/sensors/test_analyticdb_spark.py index 68df8bdbeacc0..8cef5175005f3 100644 --- a/tests/providers/alibaba/cloud/sensors/test_analyticdb_spark.py +++ b/tests/providers/alibaba/cloud/sensors/test_analyticdb_spark.py @@ -49,24 +49,24 @@ def test_get_hook(self, mock_service): def test_poke_terminal_state(self, mock_service): """Test poke_terminal_state works as expected with COMPLETED application.""" # Given - mock_service.return_value.get_spark_state.return_value = "COMPLETED" + mock_service.get_spark_state.return_value = "COMPLETED" # When res = self.sensor.poke(None) # Then assert res is True - mock_service.return_value.get_spark_state.assert_called_once_with(MOCK_ADB_SPARK_ID) + mock_service.get_spark_state.assert_called_once_with(MOCK_ADB_SPARK_ID) @mock.patch(ADB_SPARK_SENSOR_STRING.format("AnalyticDBSparkSensor.get_hook")) def test_poke_non_terminal_state(self, mock_service): """Test poke_terminal_state works as expected with RUNNING 
application.""" # Given - mock_service.return_value.get_spark_state.return_value = "RUNNING" + mock_service.get_spark_state.return_value = "RUNNING" # When res = self.sensor.poke(None) # Then assert res is False - mock_service.return_value.get_spark_state.assert_called_once_with(MOCK_ADB_SPARK_ID) + mock_service.get_spark_state.assert_called_once_with(MOCK_ADB_SPARK_ID) From 951b1256d32b0f1d7998f34567ba4dbb36132c02 Mon Sep 17 00:00:00 2001 From: Qian Sun Date: Wed, 21 Jun 2023 13:33:04 +0800 Subject: [PATCH 6/8] fix static check --- airflow/providers/alibaba/cloud/operators/analyticdb_spark.py | 4 +--- airflow/providers/alibaba/cloud/sensors/analyticdb_spark.py | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/airflow/providers/alibaba/cloud/operators/analyticdb_spark.py b/airflow/providers/alibaba/cloud/operators/analyticdb_spark.py index 6ae88bb3d375c..6ddd47dab28b0 100644 --- a/airflow/providers/alibaba/cloud/operators/analyticdb_spark.py +++ b/airflow/providers/alibaba/cloud/operators/analyticdb_spark.py @@ -52,9 +52,7 @@ def __init__( @cached_property def get_hook(self) -> AnalyticDBSparkHook: - """ - Get valid hook. - """ + """Get valid hook.""" if self._adb_spark_hook is None or not isinstance(self._adb_spark_hook, AnalyticDBSparkHook): self._adb_spark_hook = AnalyticDBSparkHook( adb_spark_conn_id=self._adb_spark_conn_id, region=self._region diff --git a/airflow/providers/alibaba/cloud/sensors/analyticdb_spark.py b/airflow/providers/alibaba/cloud/sensors/analyticdb_spark.py index 24640a5b0cc7e..fb6a962d43798 100644 --- a/airflow/providers/alibaba/cloud/sensors/analyticdb_spark.py +++ b/airflow/providers/alibaba/cloud/sensors/analyticdb_spark.py @@ -54,9 +54,7 @@ def __init__( @cached_property def get_hook(self) -> AnalyticDBSparkHook: - """ - Get valid hook. 
- """ + """Get valid hook.""" if self._adb_spark_hook is None or not isinstance(self._adb_spark_hook, AnalyticDBSparkHook): self._adb_spark_hook = AnalyticDBSparkHook( adb_spark_conn_id=self._adb_spark_conn_id, region=self._region From 91feaf2413a6a717c6343d7fa17a2997657ac949 Mon Sep 17 00:00:00 2001 From: Qian Sun Date: Wed, 21 Jun 2023 14:52:22 +0800 Subject: [PATCH 7/8] fix static check --- airflow/providers/alibaba/provider.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/providers/alibaba/provider.yaml b/airflow/providers/alibaba/provider.yaml index c9ea17456a845..660dccdcd6483 100644 --- a/airflow/providers/alibaba/provider.yaml +++ b/airflow/providers/alibaba/provider.yaml @@ -49,7 +49,7 @@ integrations: - /docs/apache-airflow-providers-alibaba/operators/oss.rst tags: [alibaba] - integration-name: Alibaba Cloud AnalyticDB Spark - external-doc-url: https://www.alibabacloud.com/help/en/analyticdb-for-mysql/latest/spark-offline-application-development + external-doc-url: https://www.alibabacloud.com/help/en/analyticdb-for-mysql/latest/spark-developerment how-to-guide: - /docs/apache-airflow-providers-alibaba/operators/analyticdb_spark.rst tags: [alibaba] From a9ebc53765efd5be8b0e6ea3b2ba2baa4120e1b4 Mon Sep 17 00:00:00 2001 From: Qian Sun Date: Fri, 23 Jun 2023 20:38:03 +0800 Subject: [PATCH 8/8] fix index.rst --- docs/apache-airflow-providers-alibaba/index.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/apache-airflow-providers-alibaba/index.rst b/docs/apache-airflow-providers-alibaba/index.rst index 9712bab554f68..c050e7e9da753 100644 --- a/docs/apache-airflow-providers-alibaba/index.rst +++ b/docs/apache-airflow-providers-alibaba/index.rst @@ -83,12 +83,12 @@ for the minimum Airflow version supported) via Requirements ------------ -The minimum Apache Airflow version supported by this provider package is ``2.5.0``. +The minimum Apache Airflow version supported by this provider package is ``2.4.0``. ============================ ================== PIP package Version required ============================ ================== -``apache-airflow`` ``>=2.5.0`` +``apache-airflow`` ``>=2.4.0`` ``oss2`` ``>=2.14.0`` ``alibabacloud_adb20211201`` ``>=1.0.0`` ``alibabacloud_tea_openapi`` ``>=0.3.7``