diff --git a/airflow/providers/alibaba/cloud/hooks/analyticdb_spark.py b/airflow/providers/alibaba/cloud/hooks/analyticdb_spark.py new file mode 100644 index 0000000000000..bf3eca172220b --- /dev/null +++ b/airflow/providers/alibaba/cloud/hooks/analyticdb_spark.py @@ -0,0 +1,366 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import json +from enum import Enum +from typing import Any, Sequence + +from alibabacloud_adb20211201.client import Client +from alibabacloud_adb20211201.models import ( + GetSparkAppLogRequest, + GetSparkAppStateRequest, + GetSparkAppWebUiAddressRequest, + KillSparkAppRequest, + SubmitSparkAppRequest, + SubmitSparkAppResponse, +) +from alibabacloud_tea_openapi.models import Config + +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook +from airflow.utils.log.logging_mixin import LoggingMixin + + +class AppState(Enum): + """ + AnalyticDB Spark application states doc: + https://www.alibabacloud.com/help/en/analyticdb-for-mysql/latest/api-doc-adb-2021-12-01-api-struct + -sparkappinfo. + + """ + + SUBMITTED = "SUBMITTED" + STARTING = "STARTING" + RUNNING = "RUNNING" + FAILING = "FAILING" + FAILED = "FAILED" + KILLING = "KILLING" + KILLED = "KILLED" + SUCCEEDING = "SUCCEEDING" + COMPLETED = "COMPLETED" + FATAL = "FATAL" + UNKNOWN = "UNKNOWN" + + +class AnalyticDBSparkHook(BaseHook, LoggingMixin): + """ + Hook for AnalyticDB MySQL Spark through the REST API. + + :param adb_spark_conn_id: The Airflow connection used for AnalyticDB MySQL Spark credentials. + :param region: AnalyticDB MySQL region you want to submit spark application. + """ + + TERMINAL_STATES = {AppState.COMPLETED, AppState.FAILED, AppState.FATAL, AppState.KILLED} + + conn_name_attr = "alibabacloud_conn_id" + default_conn_name = "adb_spark_default" + conn_type = "adb_spark" + hook_name = "AnalyticDB Spark" + + def __init__( + self, adb_spark_conn_id: str = "adb_spark_default", region: str | None = None, *args, **kwargs + ) -> None: + self.adb_spark_conn_id = adb_spark_conn_id + self.adb_spark_conn = self.get_connection(adb_spark_conn_id) + self.region = self.get_default_region() if region is None else region + super().__init__(*args, **kwargs) + + def submit_spark_app( + self, cluster_id: str, rg_name: str, *args: Any, **kwargs: Any + ) -> SubmitSparkAppResponse: + """ + Perform request to submit spark application. + + :param cluster_id: The cluster ID of AnalyticDB MySQL 3.0 Data Lakehouse. + :param rg_name: The name of resource group in AnalyticDB MySQL 3.0 Data Lakehouse cluster. 
+ """ + self.log.info("Submitting application") + request = SubmitSparkAppRequest( + dbcluster_id=cluster_id, + resource_group_name=rg_name, + data=json.dumps(self.build_submit_app_data(*args, **kwargs)), + app_type="BATCH", + ) + try: + return self.get_adb_spark_client().submit_spark_app(request) + except Exception as e: + self.log.error(e) + raise AirflowException("Errors when submit spark application") from e + + def submit_spark_sql( + self, cluster_id: str, rg_name: str, *args: Any, **kwargs: Any + ) -> SubmitSparkAppResponse: + """ + Perform request to submit spark sql. + + :param cluster_id: The cluster ID of AnalyticDB MySQL 3.0 Data Lakehouse. + :param rg_name: The name of resource group in AnalyticDB MySQL 3.0 Data Lakehouse cluster. + """ + self.log.info("Submitting Spark SQL") + request = SubmitSparkAppRequest( + dbcluster_id=cluster_id, + resource_group_name=rg_name, + data=self.build_submit_sql_data(*args, **kwargs), + app_type="SQL", + ) + try: + return self.get_adb_spark_client().submit_spark_app(request) + except Exception as e: + self.log.error(e) + raise AirflowException("Errors when submit spark sql") from e + + def get_spark_state(self, app_id: str) -> str: + """ + Fetch the state of the specified spark application. + + :param app_id: identifier of the spark application + """ + self.log.debug("Fetching state for spark application %s", app_id) + try: + return ( + self.get_adb_spark_client() + .get_spark_app_state(GetSparkAppStateRequest(app_id=app_id)) + .body.data.state + ) + except Exception as e: + self.log.error(e) + raise AirflowException(f"Errors when fetching state for spark application: {app_id}") from e + + def get_spark_web_ui_address(self, app_id: str) -> str: + """ + Fetch the web ui address of the specified spark application. + + :param app_id: identifier of the spark application + """ + self.log.debug("Fetching web ui address for spark application %s", app_id) + try: + return ( + self.get_adb_spark_client() + .get_spark_app_web_ui_address(GetSparkAppWebUiAddressRequest(app_id=app_id)) + .body.data.web_ui_address + ) + except Exception as e: + self.log.error(e) + raise AirflowException( + f"Errors when fetching web ui address for spark application: {app_id}" + ) from e + + def get_spark_log(self, app_id: str) -> str: + """ + Get the logs for a specified spark application. + + :param app_id: identifier of the spark application + """ + self.log.debug("Fetching log for spark application %s", app_id) + try: + return ( + self.get_adb_spark_client() + .get_spark_app_log(GetSparkAppLogRequest(app_id=app_id)) + .body.data.log_content + ) + except Exception as e: + self.log.error(e) + raise AirflowException(f"Errors when fetching log for spark application: {app_id}") from e + + def kill_spark_app(self, app_id: str) -> None: + """ + Kill the specified spark application. 
+ + :param app_id: identifier of the spark application + """ + self.log.info("Killing spark application %s", app_id) + try: + self.get_adb_spark_client().kill_spark_app(KillSparkAppRequest(app_id=app_id)) + except Exception as e: + self.log.error(e) + raise AirflowException(f"Errors when killing spark application: {app_id}") from e + + @staticmethod + def build_submit_app_data( + file: str | None = None, + class_name: str | None = None, + args: Sequence[str | int | float] | None = None, + conf: dict[Any, Any] | None = None, + jars: Sequence[str] | None = None, + py_files: Sequence[str] | None = None, + files: Sequence[str] | None = None, + driver_resource_spec: str | None = None, + executor_resource_spec: str | None = None, + num_executors: int | str | None = None, + archives: Sequence[str] | None = None, + name: str | None = None, + ) -> dict: + """ + Build the submit application request data. + + :param file: path of the file containing the application to execute. + :param class_name: name of the application Java/Spark main class. + :param args: application command line arguments. + :param conf: Spark configuration properties. + :param jars: jars to be used in this application. + :param py_files: python files to be used in this application. + :param files: files to be used in this application. + :param driver_resource_spec: The resource specifications of the Spark driver. + :param executor_resource_spec: The resource specifications of each Spark executor. + :param num_executors: number of executors to launch for this application. + :param archives: archives to be used in this application. + :param name: name of this application. + """ + if file is None: + raise ValueError("Parameter file is need when submit spark application.") + + data: dict[str, Any] = {"file": file} + extra_conf: dict[str, str] = {} + + if class_name: + data["className"] = class_name + if args and AnalyticDBSparkHook._validate_list_of_stringables(args): + data["args"] = [str(val) for val in args] + if driver_resource_spec: + extra_conf["spark.driver.resourceSpec"] = driver_resource_spec + if executor_resource_spec: + extra_conf["spark.executor.resourceSpec"] = executor_resource_spec + if num_executors: + extra_conf["spark.executor.instances"] = str(num_executors) + data["conf"] = extra_conf.copy() + if conf and AnalyticDBSparkHook._validate_extra_conf(conf): + data["conf"].update(conf) + if jars and AnalyticDBSparkHook._validate_list_of_stringables(jars): + data["jars"] = jars + if py_files and AnalyticDBSparkHook._validate_list_of_stringables(py_files): + data["pyFiles"] = py_files + if files and AnalyticDBSparkHook._validate_list_of_stringables(files): + data["files"] = files + if archives and AnalyticDBSparkHook._validate_list_of_stringables(archives): + data["archives"] = archives + if name: + data["name"] = name + + return data + + @staticmethod + def build_submit_sql_data( + sql: str | None = None, + conf: dict[Any, Any] | None = None, + driver_resource_spec: str | None = None, + executor_resource_spec: str | None = None, + num_executors: int | str | None = None, + name: str | None = None, + ) -> str: + """ + Build the submit spark sql request data. + + :param sql: The SQL query to execute. (templated) + :param conf: Spark configuration properties. + :param driver_resource_spec: The resource specifications of the Spark driver. + :param executor_resource_spec: The resource specifications of each Spark executor. + :param num_executors: number of executors to launch for this application. 
+ :param name: name of this application. + """ + if sql is None: + raise ValueError("Parameter sql is need when submit spark sql.") + + extra_conf: dict[str, str] = {} + formatted_conf = "" + + if driver_resource_spec: + extra_conf["spark.driver.resourceSpec"] = driver_resource_spec + if executor_resource_spec: + extra_conf["spark.executor.resourceSpec"] = executor_resource_spec + if num_executors: + extra_conf["spark.executor.instances"] = str(num_executors) + if name: + extra_conf["spark.app.name"] = name + if conf and AnalyticDBSparkHook._validate_extra_conf(conf): + extra_conf.update(conf) + for key, value in extra_conf.items(): + formatted_conf += f"set {key} = {value};" + + return (formatted_conf + sql).strip() + + @staticmethod + def _validate_list_of_stringables(vals: Sequence[str | int | float]) -> bool: + """ + Check the values in the provided list can be converted to strings. + + :param vals: list to validate + """ + if ( + vals is None + or not isinstance(vals, (tuple, list)) + or any(1 for val in vals if not isinstance(val, (str, int, float))) + ): + raise ValueError("List of strings expected") + return True + + @staticmethod + def _validate_extra_conf(conf: dict[Any, Any]) -> bool: + """ + Check configuration values are either strings or ints. + + :param conf: configuration variable + """ + if conf: + if not isinstance(conf, dict): + raise ValueError("'conf' argument must be a dict") + if any(True for k, v in conf.items() if not (v and isinstance(v, str) or isinstance(v, int))): + raise ValueError("'conf' values must be either strings or ints") + return True + + def get_adb_spark_client(self) -> Client: + """Get valid AnalyticDB MySQL Spark client.""" + assert self.region is not None + + extra_config = self.adb_spark_conn.extra_dejson + auth_type = extra_config.get("auth_type", None) + if not auth_type: + raise ValueError("No auth_type specified in extra_config.") + + if auth_type != "AK": + raise ValueError(f"Unsupported auth_type: {auth_type}") + adb_spark_access_key_id = extra_config.get("access_key_id", None) + adb_spark_access_secret = extra_config.get("access_key_secret", None) + if not adb_spark_access_key_id: + raise ValueError(f"No access_key_id is specified for connection: {self.adb_spark_conn_id}") + + if not adb_spark_access_secret: + raise ValueError(f"No access_key_secret is specified for connection: {self.adb_spark_conn_id}") + + return Client( + Config( + access_key_id=adb_spark_access_key_id, + access_key_secret=adb_spark_access_secret, + endpoint=f"adb.{self.region}.aliyuncs.com", + ) + ) + + def get_default_region(self) -> str | None: + """Get default region from connection.""" + extra_config = self.adb_spark_conn.extra_dejson + auth_type = extra_config.get("auth_type", None) + if not auth_type: + raise ValueError("No auth_type specified in extra_config. ") + + if auth_type != "AK": + raise ValueError(f"Unsupported auth_type: {auth_type}") + + default_region = extra_config.get("region", None) + if not default_region: + raise ValueError(f"No region is specified for connection: {self.adb_spark_conn}") + return default_region diff --git a/airflow/providers/alibaba/cloud/operators/analyticdb_spark.py b/airflow/providers/alibaba/cloud/operators/analyticdb_spark.py new file mode 100644 index 0000000000000..6ddd47dab28b0 --- /dev/null +++ b/airflow/providers/alibaba/cloud/operators/analyticdb_spark.py @@ -0,0 +1,223 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from functools import cached_property +from time import sleep +from typing import TYPE_CHECKING, Any, Sequence + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.alibaba.cloud.hooks.analyticdb_spark import AnalyticDBSparkHook, AppState + +if TYPE_CHECKING: + from airflow.utils.context import Context + + +class AnalyticDBSparkBaseOperator(BaseOperator): + """Abstract base class that defines how users develop AnalyticDB Spark.""" + + def __init__( + self, + *, + adb_spark_conn_id: str = "adb_spark_default", + region: str | None = None, + polling_interval: int = 0, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + + self.app_id: str | None = None + self.polling_interval = polling_interval + + self._adb_spark_conn_id = adb_spark_conn_id + self._region = region + + self._adb_spark_hook: AnalyticDBSparkHook | None = None + + @cached_property + def get_hook(self) -> AnalyticDBSparkHook: + """Get valid hook.""" + if self._adb_spark_hook is None or not isinstance(self._adb_spark_hook, AnalyticDBSparkHook): + self._adb_spark_hook = AnalyticDBSparkHook( + adb_spark_conn_id=self._adb_spark_conn_id, region=self._region + ) + return self._adb_spark_hook + + def execute(self, context: Context) -> Any: + ... + + def monitor_application(self): + self.log.info("Monitoring application with %s", self.app_id) + + if self.polling_interval > 0: + self.poll_for_termination(self.app_id) + + def poll_for_termination(self, app_id: str) -> None: + """ + Poll for spark application termination. + + :param app_id: id of the spark application to monitor + """ + hook = self.get_hook + state = hook.get_spark_state(app_id) + while AppState(state) not in AnalyticDBSparkHook.TERMINAL_STATES: + self.log.debug("Application with id %s is in state: %s", app_id, state) + sleep(self.polling_interval) + state = hook.get_spark_state(app_id) + self.log.info("Application with id %s terminated with state: %s", app_id, state) + self.log.info( + "Web ui address is %s for application with id %s", hook.get_spark_web_ui_address(app_id), app_id + ) + self.log.info(hook.get_spark_log(app_id)) + if AppState(state) != AppState.COMPLETED: + raise AirflowException(f"Application {app_id} did not succeed") + + def on_kill(self) -> None: + self.kill() + + def kill(self) -> None: + """Kill the specified application.""" + if self.app_id is not None: + self.get_hook.kill_spark_app(self.app_id) + + +class AnalyticDBSparkSQLOperator(AnalyticDBSparkBaseOperator): + """ + This operator wraps the AnalyticDB Spark REST API, allowing submission of a Spark SQL + application to the underlying cluster. + + :param sql: The SQL query to execute. + :param conf: Spark configuration properties. + :param driver_resource_spec: The resource specifications of the Spark driver.
+ :param executor_resource_spec: The resource specifications of each Spark executor. + :param num_executors: number of executors to launch for this application. + :param name: name of this application. + :param cluster_id: The cluster ID of AnalyticDB MySQL 3.0 Data Lakehouse. + :param rg_name: The name of resource group in AnalyticDB MySQL 3.0 Data Lakehouse cluster. + """ + + template_fields: Sequence[str] = ("spark_params",) + template_fields_renderers = {"spark_params": "json"} + + def __init__( + self, + *, + sql: str, + conf: dict[Any, Any] | None = None, + driver_resource_spec: str | None = None, + executor_resource_spec: str | None = None, + num_executors: int | str | None = None, + name: str | None = None, + cluster_id: str, + rg_name: str, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + + self.spark_params = { + "sql": sql, + "conf": conf, + "driver_resource_spec": driver_resource_spec, + "executor_resource_spec": executor_resource_spec, + "num_executors": num_executors, + "name": name, + } + + self._cluster_id = cluster_id + self._rg_name = rg_name + + def execute(self, context: Context) -> Any: + submit_response = self.get_hook.submit_spark_sql( + cluster_id=self._cluster_id, rg_name=self._rg_name, **self.spark_params + ) + self.app_id = submit_response.body.data.app_id + self.monitor_application() + return self.app_id + + +class AnalyticDBSparkBatchOperator(AnalyticDBSparkBaseOperator): + """ + This operator wraps the AnalyticDB Spark REST API, allowing submission of a Spark batch + application to the underlying cluster. + + :param file: path of the file containing the application to execute. + :param class_name: name of the application Java/Spark main class. + :param args: application command line arguments. + :param conf: Spark configuration properties. + :param jars: jars to be used in this application. + :param py_files: python files to be used in this application. + :param files: files to be used in this application. + :param driver_resource_spec: The resource specifications of the Spark driver. + :param executor_resource_spec: The resource specifications of each Spark executor. + :param num_executors: number of executors to launch for this application. + :param archives: archives to be used in this application. + :param name: name of this application. + :param cluster_id: The cluster ID of AnalyticDB MySQL 3.0 Data Lakehouse. + :param rg_name: The name of resource group in AnalyticDB MySQL 3.0 Data Lakehouse cluster.
+ """ + + template_fields: Sequence[str] = ("spark_params",) + template_fields_renderers = {"spark_params": "json"} + + def __init__( + self, + *, + file: str, + class_name: str | None = None, + args: Sequence[str | int | float] | None = None, + conf: dict[Any, Any] | None = None, + jars: Sequence[str] | None = None, + py_files: Sequence[str] | None = None, + files: Sequence[str] | None = None, + driver_resource_spec: str | None = None, + executor_resource_spec: str | None = None, + num_executors: int | str | None = None, + archives: Sequence[str] | None = None, + name: str | None = None, + cluster_id: str, + rg_name: str, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + + self.spark_params = { + "file": file, + "class_name": class_name, + "args": args, + "conf": conf, + "jars": jars, + "py_files": py_files, + "files": files, + "driver_resource_spec": driver_resource_spec, + "executor_resource_spec": executor_resource_spec, + "num_executors": num_executors, + "archives": archives, + "name": name, + } + + self._cluster_id = cluster_id + self._rg_name = rg_name + + def execute(self, context: Context) -> Any: + submit_response = self.get_hook.submit_spark_app( + cluster_id=self._cluster_id, rg_name=self._rg_name, **self.spark_params + ) + self.app_id = submit_response.body.data.app_id + self.monitor_application() + return self.app_id diff --git a/airflow/providers/alibaba/cloud/sensors/analyticdb_spark.py b/airflow/providers/alibaba/cloud/sensors/analyticdb_spark.py new file mode 100644 index 0000000000000..fb6a962d43798 --- /dev/null +++ b/airflow/providers/alibaba/cloud/sensors/analyticdb_spark.py @@ -0,0 +1,68 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from functools import cached_property +from typing import TYPE_CHECKING, Any, Sequence + +from airflow.providers.alibaba.cloud.hooks.analyticdb_spark import AnalyticDBSparkHook, AppState +from airflow.sensors.base import BaseSensorOperator + +if TYPE_CHECKING: + from airflow.utils.context import Context + + +class AnalyticDBSparkSensor(BaseSensorOperator): + """ + Monitor an AnalyticDB Spark application for termination. + + :param app_id: identifier of the monitored spark application. + :param adb_spark_conn_id: reference to a pre-defined ADB Spark connection. + :param region: AnalyticDB MySQL region you want to submit spark application.
+ """ + + template_fields: Sequence[str] = ("app_id",) + + def __init__( + self, + *, + app_id: str, + adb_spark_conn_id: str = "adb_spark_default", + region: str | None = None, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.app_id = app_id + self._region = region + self._adb_spark_conn_id = adb_spark_conn_id + self._adb_spark_hook: AnalyticDBSparkHook | None = None + + @cached_property + def get_hook(self) -> AnalyticDBSparkHook: + """Get valid hook.""" + if self._adb_spark_hook is None or not isinstance(self._adb_spark_hook, AnalyticDBSparkHook): + self._adb_spark_hook = AnalyticDBSparkHook( + adb_spark_conn_id=self._adb_spark_conn_id, region=self._region + ) + return self._adb_spark_hook + + def poke(self, context: Context) -> bool: + app_id = self.app_id + + state = self.get_hook.get_spark_state(app_id) + return AppState(state) in AnalyticDBSparkHook.TERMINAL_STATES diff --git a/airflow/providers/alibaba/provider.yaml b/airflow/providers/alibaba/provider.yaml index aa2d99897189a..660dccdcd6483 100644 --- a/airflow/providers/alibaba/provider.yaml +++ b/airflow/providers/alibaba/provider.yaml @@ -38,6 +38,8 @@ versions: dependencies: - apache-airflow>=2.4.0 - oss2>=2.14.0 + - alibabacloud_adb20211201>=1.0.0 + - alibabacloud_tea_openapi>=0.3.7 integrations: - integration-name: Alibaba Cloud OSS @@ -46,26 +48,42 @@ integrations: how-to-guide: - /docs/apache-airflow-providers-alibaba/operators/oss.rst tags: [alibaba] + - integration-name: Alibaba Cloud AnalyticDB Spark + external-doc-url: https://www.alibabacloud.com/help/en/analyticdb-for-mysql/latest/spark-developerment + how-to-guide: + - /docs/apache-airflow-providers-alibaba/operators/analyticdb_spark.rst + tags: [alibaba] operators: - integration-name: Alibaba Cloud OSS python-modules: - airflow.providers.alibaba.cloud.operators.oss + - integration-name: Alibaba Cloud AnalyticDB Spark + python-modules: + - airflow.providers.alibaba.cloud.operators.analyticdb_spark sensors: - integration-name: Alibaba Cloud OSS python-modules: - airflow.providers.alibaba.cloud.sensors.oss_key + - integration-name: Alibaba Cloud AnalyticDB Spark + python-modules: + - airflow.providers.alibaba.cloud.sensors.analyticdb_spark hooks: - integration-name: Alibaba Cloud OSS python-modules: - airflow.providers.alibaba.cloud.hooks.oss + - integration-name: Alibaba Cloud AnalyticDB Spark + python-modules: + - airflow.providers.alibaba.cloud.hooks.analyticdb_spark connection-types: - hook-class-name: airflow.providers.alibaba.cloud.hooks.oss.OSSHook connection-type: oss + - hook-class-name: airflow.providers.alibaba.cloud.hooks.analyticdb_spark.AnalyticDBSparkHook + connection-type: adb_spark logging: - airflow.providers.alibaba.cloud.log.oss_task_handler.OSSTaskHandler diff --git a/docs/apache-airflow-providers-alibaba/connections/alibaba.rst b/docs/apache-airflow-providers-alibaba/connections/alibaba.rst index d697b4dc0c753..4cf4747d7e2a5 100644 --- a/docs/apache-airflow-providers-alibaba/connections/alibaba.rst +++ b/docs/apache-airflow-providers-alibaba/connections/alibaba.rst @@ -26,7 +26,7 @@ Authentication may be performed using `Security Token Service (STS) or a signed Default Connection IDs ---------------------- -The default connection ID is ``oss_default``. +The default connection IDs are ``oss_default`` and ``adb_spark_default``. 
Configuring the Connection -------------------------- diff --git a/docs/apache-airflow-providers-alibaba/index.rst b/docs/apache-airflow-providers-alibaba/index.rst index 8e0da17c71fdd..c050e7e9da753 100644 --- a/docs/apache-airflow-providers-alibaba/index.rst +++ b/docs/apache-airflow-providers-alibaba/index.rst @@ -85,11 +85,13 @@ Requirements The minimum Apache Airflow version supported by this provider package is ``2.4.0``. -================== ================== -PIP package Version required -================== ================== -``apache-airflow`` ``>=2.4.0`` -``oss2`` ``>=2.14.0`` -================== ================== +============================ ================== +PIP package Version required +============================ ================== +``apache-airflow`` ``>=2.4.0`` +``oss2`` ``>=2.14.0`` +``alibabacloud_adb20211201`` ``>=1.0.0`` +``alibabacloud_tea_openapi`` ``>=0.3.7`` +============================ ================== .. include:: ../../airflow/providers/alibaba/CHANGELOG.rst diff --git a/docs/apache-airflow-providers-alibaba/operators/analyticdb_spark.rst b/docs/apache-airflow-providers-alibaba/operators/analyticdb_spark.rst new file mode 100644 index 0000000000000..ac3f0638ad24c --- /dev/null +++ b/docs/apache-airflow-providers-alibaba/operators/analyticdb_spark.rst @@ -0,0 +1,45 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +Alibaba Cloud AnalyticDB Spark Operators +======================================== + +Overview +-------- + +Airflow to Alibaba Cloud AnalyticDB Spark integration provides several operators to develop spark batch and sql applications. + + - :class:`~airflow.providers.alibaba.cloud.operators.analyticdb_spark.AnalyticDBSparkBatchOperator` + - :class:`~airflow.providers.alibaba.cloud.operators.analyticdb_spark.AnalyticDBSparkSQLOperator` + +Develop Spark batch applications +------------------------------------------- + +Purpose +""""""" + +This example dag uses ``AnalyticDBSparkBatchOperator`` to submit Spark Pi and Spark Logistic regression applications. + +Defining tasks +"""""""""""""" + +In the following code we submit Spark Pi and Spark Logistic regression applications. + +.. 
exampleinclude:: /../../tests/system/providers/alibaba/example_adb_spark_batch.py + :language: python + :start-after: [START howto_operator_adb_spark_batch] + :end-before: [END howto_operator_adb_spark_batch] diff --git a/docs/conf.py b/docs/conf.py index 27b3e07384d90..f45956c9c0c88 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -588,6 +588,8 @@ def _get_params(root_schema: dict, prefix: str = "", default_section: str = "") autodoc_mock_imports = [ "MySQLdb", "adal", + "alibabacloud_adb20211201", + "alibabacloud_tea_openapi", "analytics", "azure", "azure.cosmos", diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index babb506f4108f..9a4ae4eaaa34c 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -861,6 +861,7 @@ kwargs KYLIN Kylin kylin +Lakehouse LanguageServiceClient lastname latencies @@ -1356,6 +1357,7 @@ sourceArchiveUrl sourceRepository sourceUploadUrl Spark +sparkappinfo sparkApplication sparkcmd SparkPi diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index d77fd32e8d46d..a25786120f6c9 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -11,6 +11,8 @@ }, "alibaba": { "deps": [ + "alibabacloud_adb20211201>=1.0.0", + "alibabacloud_tea_openapi>=0.3.7", "apache-airflow>=2.4.0", "oss2>=2.14.0" ], diff --git a/tests/providers/alibaba/cloud/hooks/test_analyticdb_spark.py b/tests/providers/alibaba/cloud/hooks/test_analyticdb_spark.py new file mode 100644 index 0000000000000..bf38a3f7ca666 --- /dev/null +++ b/tests/providers/alibaba/cloud/hooks/test_analyticdb_spark.py @@ -0,0 +1,203 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from unittest import mock + +from alibabacloud_adb20211201.models import ( + GetSparkAppLogResponse, + GetSparkAppLogResponseBody, + GetSparkAppLogResponseBodyData, + GetSparkAppStateResponse, + GetSparkAppStateResponseBody, + GetSparkAppStateResponseBodyData, + GetSparkAppWebUiAddressResponse, + GetSparkAppWebUiAddressResponseBody, + GetSparkAppWebUiAddressResponseBodyData, + KillSparkAppResponse, + SubmitSparkAppResponse, +) + +from airflow.providers.alibaba.cloud.hooks.analyticdb_spark import AnalyticDBSparkHook +from tests.providers.alibaba.cloud.utils.analyticdb_spark_mock import mock_adb_spark_hook_default_project_id + +ADB_SPARK_STRING = "airflow.providers.alibaba.cloud.hooks.analyticdb_spark.{}" +MOCK_ADB_SPARK_CONN_ID = "mock_id" +MOCK_ADB_CLUSTER_ID = "mock_adb_cluster_id" +MOCK_ADB_RG_NAME = "mock_adb_rg_name" +MOCK_ADB_SPARK_ID = "mock_adb_spark_id" + + +class TestAnalyticDBSparkHook: + def setup_method(self): + with mock.patch( + ADB_SPARK_STRING.format("AnalyticDBSparkHook.__init__"), + new=mock_adb_spark_hook_default_project_id, + ): + self.hook = AnalyticDBSparkHook(adb_spark_conn_id=MOCK_ADB_SPARK_CONN_ID) + + def test_build_submit_app_data(self): + """Test build submit application data for analyticDB spark as expected.""" + res_data = self.hook.build_submit_app_data( + file="oss://test_file", + class_name="com.aliyun.spark.SparkPi", + args=[1000, "test-args"], + conf={"spark.executor.instances": 1, "spark.eventLog.enabled": "true"}, + jars=["oss://1.jar", "oss://2.jar"], + py_files=["oss://1.py", "oss://2.py"], + files=["oss://1.file", "oss://2.file"], + driver_resource_spec="medium", + executor_resource_spec="medium", + num_executors=2, + archives=["oss://1.zip", "oss://2.zip"], + name="test", + ) + except_data = { + "file": "oss://test_file", + "className": "com.aliyun.spark.SparkPi", + "args": ["1000", "test-args"], + "conf": { + "spark.executor.instances": 1, + "spark.eventLog.enabled": "true", + "spark.driver.resourceSpec": "medium", + "spark.executor.resourceSpec": "medium", + }, + "jars": ["oss://1.jar", "oss://2.jar"], + "pyFiles": ["oss://1.py", "oss://2.py"], + "files": ["oss://1.file", "oss://2.file"], + "archives": ["oss://1.zip", "oss://2.zip"], + "name": "test", + } + assert res_data == except_data + + def test_build_submit_sql_data(self): + """Test build submit sql data for analyticDB spark as expected.""" + res_data = self.hook.build_submit_sql_data( + sql=""" + set spark.executor.instances=1; + show databases; + """, + conf={"spark.executor.instances": 2}, + driver_resource_spec="medium", + executor_resource_spec="medium", + num_executors=3, + name="test", + ) + except_data = ( + "set spark.driver.resourceSpec = medium;set spark.executor.resourceSpec = medium;set " + "spark.executor.instances = 2;set spark.app.name = test;\n set " + "spark.executor.instances=1;\n show databases;" + ) + assert res_data == except_data + + @mock.patch(ADB_SPARK_STRING.format("AnalyticDBSparkHook.get_adb_spark_client")) + def test_submit_spark_app(self, mock_service): + """Test submit_spark_app function works as expected.""" + # Given + mock_client = mock_service.return_value + exists_method = mock_client.submit_spark_app + exists_method.return_value = SubmitSparkAppResponse(status_code=200) + + # When + res = self.hook.submit_spark_app(MOCK_ADB_CLUSTER_ID, MOCK_ADB_RG_NAME, "oss://test.py") + + # Then + assert isinstance(res, SubmitSparkAppResponse) + mock_service.assert_called_once_with() + + 
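As an aside on the behaviour that ``test_build_submit_sql_data`` above exercises: the hook prefixes ``set key = value;`` statements ahead of the user's SQL, and an explicit ``conf`` entry overrides the ``num_executors``/resource-spec shortcuts because it is merged last. A small standalone illustration (not part of the patch), using the static method introduced in this PR:

from airflow.providers.alibaba.cloud.hooks.analyticdb_spark import AnalyticDBSparkHook

# num_executors alone would become "spark.executor.instances = 3", but the explicit
# conf entry replaces it while keeping its position in the generated prefix.
data = AnalyticDBSparkHook.build_submit_sql_data(
    sql="show databases;",
    conf={"spark.executor.instances": 2},
    num_executors=3,
    name="demo",
)
assert data == "set spark.executor.instances = 2;set spark.app.name = demo;show databases;"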
@mock.patch(ADB_SPARK_STRING.format("AnalyticDBSparkHook.get_adb_spark_client")) + def test_submit_spark_sql(self, mock_service): + """Test submit_spark_app function works as expected.""" + # Given + mock_client = mock_service.return_value + exists_method = mock_client.submit_spark_app + exists_method.return_value = SubmitSparkAppResponse(status_code=200) + + # When + res = self.hook.submit_spark_sql(MOCK_ADB_CLUSTER_ID, MOCK_ADB_RG_NAME, "SELECT 1") + + # Then + assert isinstance(res, SubmitSparkAppResponse) + mock_service.assert_called_once_with() + + @mock.patch(ADB_SPARK_STRING.format("AnalyticDBSparkHook.get_adb_spark_client")) + def test_get_spark_state(self, mock_service): + """Test get_spark_state function works as expected.""" + # Given + mock_client = mock_service.return_value + exists_method = mock_client.get_spark_app_state + exists_method.return_value = GetSparkAppStateResponse( + body=GetSparkAppStateResponseBody(data=GetSparkAppStateResponseBodyData(state="RUNNING")) + ) + + # When + res = self.hook.get_spark_state(MOCK_ADB_SPARK_ID) + + # Then + assert res == "RUNNING" + mock_service.assert_called_once_with() + + @mock.patch(ADB_SPARK_STRING.format("AnalyticDBSparkHook.get_adb_spark_client")) + def test_get_spark_web_ui_address(self, mock_service): + """Test get_spark_web_ui_address function works as expected.""" + # Given + mock_client = mock_service.return_value + exists_method = mock_client.get_spark_app_web_ui_address + exists_method.return_value = GetSparkAppWebUiAddressResponse( + body=GetSparkAppWebUiAddressResponseBody( + data=GetSparkAppWebUiAddressResponseBodyData(web_ui_address="https://mock-web-ui-address.com") + ) + ) + + # When + res = self.hook.get_spark_web_ui_address(MOCK_ADB_SPARK_ID) + + # Then + assert res == "https://mock-web-ui-address.com" + mock_service.assert_called_once_with() + + @mock.patch(ADB_SPARK_STRING.format("AnalyticDBSparkHook.get_adb_spark_client")) + def test_get_spark_log(self, mock_service): + """Test get_spark_log function works as expected.""" + # Given + mock_client = mock_service.return_value + exists_method = mock_client.get_spark_app_log + exists_method.return_value = GetSparkAppLogResponse( + body=GetSparkAppLogResponseBody(data=GetSparkAppLogResponseBodyData(log_content="Pi is 3.14")) + ) + + # When + res = self.hook.get_spark_log(MOCK_ADB_SPARK_ID) + + # Then + assert res == "Pi is 3.14" + mock_service.assert_called_once_with() + + @mock.patch(ADB_SPARK_STRING.format("AnalyticDBSparkHook.get_adb_spark_client")) + def test_kill_spark_app(self, mock_service): + """Test kill_spark_app function works as expected.""" + # Given + mock_client = mock_service.return_value + exists_method = mock_client.kill_spark_app + exists_method.return_value = KillSparkAppResponse() + + # When + self.hook.kill_spark_app(MOCK_ADB_SPARK_ID) + + # Then + mock_service.assert_called_once_with() diff --git a/tests/providers/alibaba/cloud/operators/test_analyticdb_spark.py b/tests/providers/alibaba/cloud/operators/test_analyticdb_spark.py new file mode 100644 index 0000000000000..eb2db3ff39567 --- /dev/null +++ b/tests/providers/alibaba/cloud/operators/test_analyticdb_spark.py @@ -0,0 +1,176 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock + +import pytest + +from airflow import AirflowException +from airflow.providers.alibaba.cloud.operators.analyticdb_spark import ( + AnalyticDBSparkBaseOperator, + AnalyticDBSparkBatchOperator, + AnalyticDBSparkSQLOperator, +) + +ADB_SPARK_OPERATOR_STRING = "airflow.providers.alibaba.cloud.operators.analyticdb_spark.{}" + +MOCK_FILE = "oss://test.py" +MOCK_CLUSTER_ID = "mock_cluster_id" +MOCK_RG_NAME = "mock_rg_name" +MOCK_ADB_SPARK_CONN_ID = "mock_adb_spark_conn_id" +MOCK_REGION = "mock_region" +MOCK_TASK_ID = "mock_task_id" +MOCK_APP_ID = "mock_app_id" +MOCK_SQL = "SELECT 1;" + + +class TestAnalyticDBSparkBaseOperator: + def setup_method(self): + self.operator = AnalyticDBSparkBaseOperator( + adb_spark_conn_id=MOCK_ADB_SPARK_CONN_ID, + region=MOCK_REGION, + task_id=MOCK_TASK_ID, + ) + + @mock.patch(ADB_SPARK_OPERATOR_STRING.format("AnalyticDBSparkHook")) + def test_get_hook(self, mock_hook): + """Test get_hook function works as expected.""" + self.operator.get_hook() + mock_hook.assert_called_once_with(adb_spark_conn_id=MOCK_ADB_SPARK_CONN_ID, region=MOCK_REGION) + + @mock.patch(ADB_SPARK_OPERATOR_STRING.format("AnalyticDBSparkBaseOperator.get_hook")) + def test_poll_for_termination(self, mock_hook): + """Test poll_for_termination works as expected with COMPLETED application.""" + # Given + mock_hook.get_spark_state.return_value = "COMPLETED" + + # When + self.operator.poll_for_termination(MOCK_APP_ID) + + @mock.patch(ADB_SPARK_OPERATOR_STRING.format("AnalyticDBSparkBaseOperator.get_hook")) + def test_poll_for_termination_with_exception(self, mock_hook): + """Test poll_for_termination raises AirflowException with FATAL application.""" + # Given + mock_hook.get_spark_state.return_value = "FATAL" + + # When + with pytest.raises(AirflowException, match="Application mock_app_id did not succeed"): + self.operator.poll_for_termination(MOCK_APP_ID) + + +class TestAnalyticDBSparkBatchOperator: + @mock.patch(ADB_SPARK_OPERATOR_STRING.format("AnalyticDBSparkHook")) + def test_execute(self, mock_hook): + """Test submit AnalyticDB Spark Batch Application works as expected.""" + operator = AnalyticDBSparkBatchOperator( + file=MOCK_FILE, + cluster_id=MOCK_CLUSTER_ID, + rg_name=MOCK_RG_NAME, + adb_spark_conn_id=MOCK_ADB_SPARK_CONN_ID, + region=MOCK_REGION, + task_id=MOCK_TASK_ID, + ) + + operator.execute(None) + + mock_hook.assert_called_once_with(adb_spark_conn_id=MOCK_ADB_SPARK_CONN_ID, region=MOCK_REGION) + mock_hook.return_value.submit_spark_app.assert_called_once_with( + cluster_id=MOCK_CLUSTER_ID, + rg_name=MOCK_RG_NAME, + file=MOCK_FILE, + class_name=None, + args=None, + conf=None, + jars=None, + py_files=None, + files=None, + driver_resource_spec=None, + executor_resource_spec=None, + num_executors=None, + archives=None, + name=None, + ) + + @mock.patch(ADB_SPARK_OPERATOR_STRING.format("AnalyticDBSparkBaseOperator.get_hook")) + def 
test_execute_with_exception(self, mock_hook): + """Test submit AnalyticDB Spark Batch Application raises ValueError with invalid parameter.""" + # Given + mock_hook.submit_spark_app.side_effect = ValueError("List of strings expected") + + # When + operator = AnalyticDBSparkBatchOperator( + file=MOCK_FILE, + args=(True, False), + cluster_id=MOCK_CLUSTER_ID, + rg_name=MOCK_RG_NAME, + adb_spark_conn_id=MOCK_ADB_SPARK_CONN_ID, + region=MOCK_REGION, + task_id=MOCK_TASK_ID, + ) + + with pytest.raises(ValueError, match="List of strings expected"): + operator.execute(None) + + +class TestAnalyticDBSparklSQLOperator: + @mock.patch(ADB_SPARK_OPERATOR_STRING.format("AnalyticDBSparkHook")) + def test_execute(self, mock_hook): + """Test submit AnalyticDB Spark SQL Application works as expected.""" + operator = AnalyticDBSparkSQLOperator( + sql=MOCK_SQL, + cluster_id=MOCK_CLUSTER_ID, + rg_name=MOCK_RG_NAME, + adb_spark_conn_id=MOCK_ADB_SPARK_CONN_ID, + region=MOCK_REGION, + task_id=MOCK_TASK_ID, + ) + + operator.execute(None) + + mock_hook.assert_called_once_with(adb_spark_conn_id=MOCK_ADB_SPARK_CONN_ID, region=MOCK_REGION) + mock_hook.return_value.submit_spark_sql.assert_called_once_with( + cluster_id=MOCK_CLUSTER_ID, + rg_name=MOCK_RG_NAME, + sql=MOCK_SQL, + conf=None, + driver_resource_spec=None, + executor_resource_spec=None, + num_executors=None, + name=None, + ) + + @mock.patch(ADB_SPARK_OPERATOR_STRING.format("AnalyticDBSparkBaseOperator.get_hook")) + def test_execute_with_exception(self, mock_hook): + """Test submit AnalyticDB Spark SQL Application raises ValueError with invalid parameter.""" + # Given + mock_hook.submit_spark_sql.side_effect = ValueError("List of strings expected") + + # When + operator = AnalyticDBSparkSQLOperator( + sql=MOCK_SQL, + conf={"spark.eventLog.enabled": True}, + cluster_id=MOCK_CLUSTER_ID, + rg_name=MOCK_RG_NAME, + adb_spark_conn_id=MOCK_ADB_SPARK_CONN_ID, + region=MOCK_REGION, + task_id=MOCK_TASK_ID, + ) + + with pytest.raises(ValueError, match="List of strings expected"): + operator.execute(None) diff --git a/tests/providers/alibaba/cloud/sensors/test_analyticdb_spark.py b/tests/providers/alibaba/cloud/sensors/test_analyticdb_spark.py new file mode 100644 index 0000000000000..8cef5175005f3 --- /dev/null +++ b/tests/providers/alibaba/cloud/sensors/test_analyticdb_spark.py @@ -0,0 +1,72 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from unittest import mock + +from airflow.providers.alibaba.cloud.sensors.analyticdb_spark import AnalyticDBSparkSensor +from airflow.utils import timezone + +ADB_SPARK_SENSOR_STRING = "airflow.providers.alibaba.cloud.sensors.analyticdb_spark.{}" +DEFAULT_DATE = timezone.datetime(2017, 1, 1) +MOCK_ADB_SPARK_CONN_ID = "mock_adb_spark_default" +MOCK_ADB_SPARK_ID = "mock_adb_spark_id" +MOCK_SENSOR_TASK_ID = "test-adb-spark-operator" +MOCK_REGION = "mock_region" + + +class TestAnalyticDBSparkSensor: + def setup_method(self): + self.sensor = AnalyticDBSparkSensor( + app_id=MOCK_ADB_SPARK_ID, + adb_spark_conn_id=MOCK_ADB_SPARK_CONN_ID, + region=MOCK_REGION, + task_id=MOCK_SENSOR_TASK_ID, + ) + + @mock.patch(ADB_SPARK_SENSOR_STRING.format("AnalyticDBSparkHook")) + def test_get_hook(self, mock_service): + """Test get_hook function works as expected.""" + self.sensor.get_hook() + mock_service.assert_called_once_with(adb_spark_conn_id=MOCK_ADB_SPARK_CONN_ID, region=MOCK_REGION) + + @mock.patch(ADB_SPARK_SENSOR_STRING.format("AnalyticDBSparkSensor.get_hook")) + def test_poke_terminal_state(self, mock_service): + """Test poke_terminal_state works as expected with COMPLETED application.""" + # Given + mock_service.get_spark_state.return_value = "COMPLETED" + + # When + res = self.sensor.poke(None) + + # Then + assert res is True + mock_service.get_spark_state.assert_called_once_with(MOCK_ADB_SPARK_ID) + + @mock.patch(ADB_SPARK_SENSOR_STRING.format("AnalyticDBSparkSensor.get_hook")) + def test_poke_non_terminal_state(self, mock_service): + """Test poke_terminal_state works as expected with RUNNING application.""" + # Given + mock_service.get_spark_state.return_value = "RUNNING" + + # When + res = self.sensor.poke(None) + + # Then + assert res is False + mock_service.get_spark_state.assert_called_once_with(MOCK_ADB_SPARK_ID) diff --git a/tests/providers/alibaba/cloud/utils/analyticdb_spark_mock.py b/tests/providers/alibaba/cloud/utils/analyticdb_spark_mock.py new file mode 100644 index 0000000000000..b43cc1feb2c4c --- /dev/null +++ b/tests/providers/alibaba/cloud/utils/analyticdb_spark_mock.py @@ -0,0 +1,41 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import json + +from airflow.models import Connection + +ANALYTICDB_SPARK_PROJECT_ID_HOOK_UNIT_TEST = "example-project" + + +def mock_adb_spark_hook_default_project_id( + self, adb_spark_conn_id="mock_adb_spark_default", region="mock_region" +): + self.adb_spark_conn_id = adb_spark_conn_id + self.adb_spark_conn = Connection( + extra=json.dumps( + { + "auth_type": "AK", + "access_key_id": "mock_access_key_id", + "access_key_secret": "mock_access_key_secret", + "region": "mock_region", + } + ) + ) + self.region = region diff --git a/tests/system/providers/alibaba/example_adb_spark_batch.py b/tests/system/providers/alibaba/example_adb_spark_batch.py new file mode 100644 index 0000000000000..b6945190bc65f --- /dev/null +++ b/tests/system/providers/alibaba/example_adb_spark_batch.py @@ -0,0 +1,62 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +# Ignore missing args provided by default_args +# type: ignore[call-arg] +import os +from datetime import datetime + +from airflow.models.dag import DAG +from airflow.providers.alibaba.cloud.operators.analyticdb_spark import AnalyticDBSparkBatchOperator + +ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID") +DAG_ID = "adb_spark_batch_dag" +# [START howto_operator_adb_spark_batch] +with DAG( + dag_id=DAG_ID, + start_date=datetime(2021, 1, 1), + default_args={"cluster_id": "your cluster", "rg_name": "your resource group", "region": "your region"}, + max_active_runs=1, + catchup=False, +) as dag: + + spark_pi = AnalyticDBSparkBatchOperator( + task_id="task1", + file="local:///tmp/spark-examples.jar", + class_name="org.apache.spark.examples.SparkPi", + ) + + spark_lr = AnalyticDBSparkBatchOperator( + task_id="task2", + file="local:///tmp/spark-examples.jar", + class_name="org.apache.spark.examples.SparkLR", + ) + + spark_pi >> spark_lr + + from tests.system.utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "tearDown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() +# [END howto_operator_adb_spark_batch] + +from tests.system.utils import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag) diff --git a/tests/system/providers/alibaba/example_adb_spark_sql.py b/tests/system/providers/alibaba/example_adb_spark_sql.py new file mode 100644 index 0000000000000..851880fa7355b --- /dev/null +++ b/tests/system/providers/alibaba/example_adb_spark_sql.py @@ -0,0 +1,54 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +# Ignore missing args provided by default_args +# type: ignore[call-arg] +import os +from datetime import datetime + +from airflow.models.dag import DAG +from airflow.providers.alibaba.cloud.operators.analyticdb_spark import AnalyticDBSparkSQLOperator + +ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID") +DAG_ID = "adb_spark_sql_dag" +# [START howto_operator_adb_spark_sql] +with DAG( + dag_id=DAG_ID, + start_date=datetime(2021, 1, 1), + default_args={"cluster_id": "your cluster", "rg_name": "your resource group", "region": "your region"}, + max_active_runs=1, + catchup=False, +) as dag: + + show_databases = AnalyticDBSparkSQLOperator(task_id="task1", sql="SHOW DATABASES;") + + show_tables = AnalyticDBSparkSQLOperator(task_id="task2", sql="SHOW TABLES;") + + show_databases >> show_tables + + from tests.system.utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "tearDown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() +# [END howto_operator_adb_spark_sql] + +from tests.system.utils import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag)
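The patch ships example DAGs for the two operators but none for ``AnalyticDBSparkSensor``. A possible companion sketch (not part of the patch): it submits a batch application without in-operator polling and lets the sensor wait for a terminal state; cluster, resource group and region values are placeholders, mirroring the examples above.

from datetime import datetime

from airflow.models.dag import DAG
from airflow.providers.alibaba.cloud.operators.analyticdb_spark import AnalyticDBSparkBatchOperator
from airflow.providers.alibaba.cloud.sensors.analyticdb_spark import AnalyticDBSparkSensor

with DAG(
    dag_id="adb_spark_sensor_dag",
    start_date=datetime(2021, 1, 1),
    default_args={"cluster_id": "your cluster", "rg_name": "your resource group", "region": "your region"},
    max_active_runs=1,
    catchup=False,
) as dag:
    # polling_interval defaults to 0, so the operator returns right after submission
    # and pushes the application id to XCom as its return value.
    submit = AnalyticDBSparkBatchOperator(
        task_id="submit_spark_pi",
        file="local:///tmp/spark-examples.jar",
        class_name="org.apache.spark.examples.SparkPi",
    )

    # app_id is a templated field, so the XComArg from the submit task can be passed directly;
    # the sensor keeps poking until the application reaches a terminal state.
    wait = AnalyticDBSparkSensor(
        task_id="wait_spark_pi",
        app_id=submit.output,
    )

    submit >> wait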