diff --git a/.github/ISSUE_TEMPLATE/airflow_providers_bug_report.yml b/.github/ISSUE_TEMPLATE/airflow_providers_bug_report.yml
index 639d2b2ecb3aa..64973cf02a209 100644
--- a/.github/ISSUE_TEMPLATE/airflow_providers_bug_report.yml
+++ b/.github/ISSUE_TEMPLATE/airflow_providers_bug_report.yml
@@ -41,6 +41,7 @@ body:
         - asana
         - celery
         - cloudant
+        - cloudera
         - cncf-kubernetes
         - databricks
         - datadog
diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst
index 34b220405e6b8..0fafb29d8d7b8 100644
--- a/CONTRIBUTING.rst
+++ b/CONTRIBUTING.rst
@@ -618,10 +618,10 @@ This is the full list of those extras:
 airbyte, alibaba, all, all_dbs, amazon, apache.atlas, apache.beam, apache.cassandra, apache.drill,
 apache.druid, apache.hdfs, apache.hive, apache.kylin, apache.livy, apache.pig, apache.pinot,
 apache.spark, apache.sqoop, apache.webhdfs, asana, async, atlas, aws, azure, cassandra, celery,
-cgroups, cloudant, cncf.kubernetes, crypto, dask, databricks, datadog, dbt.cloud, deprecated_api,
-devel, devel_all, devel_ci, devel_hadoop, dingding, discord, doc, docker, druid, elasticsearch,
-exasol, facebook, ftp, gcp, gcp_api, github, github_enterprise, google, google_auth, grpc,
-hashicorp, hdfs, hive, http, imap, influxdb, jdbc, jenkins, jira, kerberos, kubernetes, ldap,
+cgroups, cloudant, cloudera, cncf.kubernetes, crypto, dask, databricks, datadog, dbt.cloud,
+deprecated_api, devel, devel_all, devel_ci, devel_hadoop, dingding, discord, doc, docker, druid,
+elasticsearch, exasol, facebook, ftp, gcp, gcp_api, github, github_enterprise, google, google_auth,
+grpc, hashicorp, hdfs, hive, http, imap, influxdb, jdbc, jenkins, jira, kerberos, kubernetes, ldap,
 leveldb, microsoft.azure, microsoft.mssql, microsoft.psrp, microsoft.winrm, mongo, mssql, mysql,
 neo4j, odbc, openfaas, opsgenie, oracle, pagerduty, pandas, papermill, password, pinot, plexus,
 postgres, presto, qds, qubole, rabbitmq, redis, s3, salesforce, samba, segment, sendgrid, sentry,
@@ -690,6 +690,7 @@ apache.beam         google
 apache.druid        apache.hive
 apache.hive         amazon,microsoft.mssql,mysql,presto,samba,vertica
 apache.livy         http
+cloudera            apache.hive,http
 dbt.cloud           http
 dingding            http
 discord             http
diff --git a/INSTALL b/INSTALL
index 4354d5e33f558..45acb9e78fb51 100644
--- a/INSTALL
+++ b/INSTALL
@@ -98,10 +98,10 @@ The list of available extras:
 airbyte, alibaba, all, all_dbs, amazon, apache.atlas, apache.beam, apache.cassandra, apache.drill,
 apache.druid, apache.hdfs, apache.hive, apache.kylin, apache.livy, apache.pig, apache.pinot,
 apache.spark, apache.sqoop, apache.webhdfs, asana, async, atlas, aws, azure, cassandra, celery,
-cgroups, cloudant, cncf.kubernetes, crypto, dask, databricks, datadog, dbt.cloud, deprecated_api,
-devel, devel_all, devel_ci, devel_hadoop, dingding, discord, doc, docker, druid, elasticsearch,
-exasol, facebook, ftp, gcp, gcp_api, github, github_enterprise, google, google_auth, grpc,
-hashicorp, hdfs, hive, http, imap, influxdb, jdbc, jenkins, jira, kerberos, kubernetes, ldap,
+cgroups, cloudant, cloudera, cncf.kubernetes, crypto, dask, databricks, datadog, dbt.cloud,
+deprecated_api, devel, devel_all, devel_ci, devel_hadoop, dingding, discord, doc, docker, druid,
+elasticsearch, exasol, facebook, ftp, gcp, gcp_api, github, github_enterprise, google, google_auth,
+grpc, hashicorp, hdfs, hive, http, imap, influxdb, jdbc, jenkins, jira, kerberos, kubernetes, ldap,
 leveldb, microsoft.azure, microsoft.mssql, microsoft.psrp, microsoft.winrm, mongo, mssql, mysql,
 neo4j, odbc, openfaas, opsgenie, oracle, pagerduty, pandas, papermill, password, pinot, plexus,
 postgres, presto, qds, qubole, rabbitmq, redis, s3, salesforce, samba, segment, sendgrid, sentry,
diff --git a/airflow/providers/cloudera/CHANGELOG.rst b/airflow/providers/cloudera/CHANGELOG.rst
new file mode 100644
index 0000000000000..cef7dda80708a
--- /dev/null
+++ b/airflow/providers/cloudera/CHANGELOG.rst
@@ -0,0 +1,25 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements.  See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership.  The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License.  You may obtain a copy of the License at
+
+ ..   http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied.  See the License for the
+    specific language governing permissions and limitations
+    under the License.
+
+
+Changelog
+---------
+
+1.0.0
+.....
+
+Initial version of the provider.
diff --git a/airflow/providers/cloudera/__init__.py b/airflow/providers/cloudera/__init__.py
new file mode 100644
index 0000000000000..13a83393a9124
--- /dev/null
+++ b/airflow/providers/cloudera/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/cloudera/hooks/__init__.py b/airflow/providers/cloudera/hooks/__init__.py
new file mode 100644
index 0000000000000..91ea1272c4ad3
--- /dev/null
+++ b/airflow/providers/cloudera/hooks/__init__.py
@@ -0,0 +1,29 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
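+#
+# Illustrative sketch only (the hook and logger instances below are
+# hypothetical): hooks in this package raise subclasses of the
+# CdpHookException defined here, so callers can catch the common root
+# type and inspect the original cause, e.g.:
+#
+#     try:
+#         run_id = cde_hook.submit_job("my_job")
+#     except CdpHookException as err:
+#         log.error("CDE call failed, caused by: %s", err.raised_from)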
+"""Covers root exception for CDP hooks""" +from typing import Optional + + +class CdpHookException(Exception): + """Root exception for custom Cloudera hooks, which is used to handle any known exceptions""" + + def __init__(self, raised_from: Optional[Exception] = None, msg: Optional[str] = None) -> None: + super().__init__(raised_from, msg) + self.raised_from = raised_from + + def __str__(self) -> str: + return self.__repr__() diff --git a/airflow/providers/cloudera/hooks/cde_hook.py b/airflow/providers/cloudera/hooks/cde_hook.py new file mode 100644 index 0000000000000..c5621dcb66d86 --- /dev/null +++ b/airflow/providers/cloudera/hooks/cde_hook.py @@ -0,0 +1,331 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Holds airflow hook functionalities for CDE clusters like submitting a CDE job, +checking its status or killing it. +""" + +import os +from typing import Any, Dict, Optional, Set + +import requests +import tenacity # type: ignore + +from airflow.exceptions import AirflowException # type: ignore +from airflow.hooks.base import BaseHook # type: ignore +from airflow.providers.cloudera.hooks import CdpHookException +from airflow.providers.cloudera.model.cdp.cde import VirtualCluster +from airflow.providers.cloudera.model.connection import CdeConnection +from airflow.providers.cloudera.security import SecurityError +from airflow.providers.cloudera.security.cde_security import BearerAuth, CdeApiTokenAuth, CdeTokenAuthResponse +from airflow.providers.cloudera.security.cdp_security import CdpAccessKeyCredentials, CdpAccessKeyV2TokenAuth +from airflow.providers.cloudera.security.token_cache import EncryptedFileTokenCacheStrategy +from airflow.providers.http.hooks.http import HttpHook # type: ignore + + +class CdeHookException(CdpHookException): + """Root exception for the CdeHook which is used to handle any known exceptions""" + + +class CdeHook(BaseHook): # type: ignore + """A wrapper around the CDE Virtual Cluster REST API.""" + + conn_name_attr = "cde_conn_id" + conn_type = "cloudera_data_engineering" + hook_name = "Cloudera Data Engineering" + + @staticmethod + def get_ui_field_behaviour() -> Dict: + """Returns custom field behaviour""" + return { + "hidden_fields": ["schema", "port"], + "relabeling": { + "host": "Virtual Cluster API endpoint", + "login": "CDP Access Key", + "password": "CDP Private Key", + }, + } + + DEFAULT_CONN_ID = "cde_runtime_api" + # Gives a total of at least 2^8+2^7+...2=510 seconds of retry with exponential backoff + DEFAULT_NUM_RETRIES = 9 + DEFAULT_API_TIMEOUT = 30 + + def __init__( + self, + connection_id: str = DEFAULT_CONN_ID, + num_retries: int = DEFAULT_NUM_RETRIES, + api_timeout: int = DEFAULT_API_TIMEOUT, + ) -> None: + """ + Create a new CdeHook. 
The connection parameters are eagerly validated to highlight + any problems early. + + :param connection_id: The connection name for the target virtual cluster API + (default: {CdeHook.DEFAULT_CONN_ID}). + :param num_retries: The number of times API requests should be retried if a server-side + error or transport error is encountered (default: {CdeHook.DEFAULT_NUM_RETRIES}). + :param api_timeout: The timeout in seconds after which, if no response has been received + from the API, a request should be abandoned and retried + (default: {CdeHook.DEFAULT_API_TIMEOUT}). + """ + super().__init__(connection_id) + self.cde_conn_id = connection_id + airflow_connection = self.get_connection(self.cde_conn_id) + self.connection = CdeConnection.from_airflow_connection(airflow_connection) + self.num_retries = num_retries + self.api_timeout = api_timeout + + def _do_api_call( + self, method: str, endpoint: str, params: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: + """ + Execute the API call. Requests are retried for connection errors and server-side errors + using an exponential backoff. + + :param method: HTTP method + :param endpoint: URL path of REST endpoint, excluding the API prefix, e.g "/jobs/myjob/run". + If the endpoint does not start with '/' this will be added + :param params: A dictionary of parameters to send in either HTTP body as a JSON document + or as URL parameters for GET requests + :param body: A dictionary to send in the HTTP body as a JSON document + :return: The API response converted to a Python dictionary + or an AirflowException if the API returns an error + """ + + if self.connection.proxy: + self.log.debug("Setting up proxy environment variables") + os.environ["HTTPS_PROXY"] = self.connection.proxy + os.environ["https_proxy"] = self.connection.proxy + + if self.connection.is_external(): + cde_token = self.get_cde_token() + else: + self.log.info("Using internal authentication mechanisms.") + + endpoint = endpoint if endpoint.startswith("/") else f"/{endpoint}" + if self.connection.is_internal(): + endpoint = self.connection.api_base_route + endpoint + + self.log.debug( + "Executing API call: (Method: %s, Endpoint: %s, Parameters: %s)", + method, + endpoint, + params, + ) + http = HttpHook(method.upper(), http_conn_id=self.cde_conn_id) + retry_handler = RetryHandler() + + try: + extra_options: Dict[str, Any] = dict( + timeout=self.api_timeout, + # we check the response ourselves in RetryHandler + check_response=False, + ) + + if self.connection.insecure: + self.log.debug("Setting session verify to False") + extra_options = {**extra_options, "verify": False} + else: + ca_cert = self.connection.ca_cert_path + self.log.debug("ca_cert is %s", ca_cert) + if ca_cert: + self.log.debug("Setting session verify to %s", ca_cert) + extra_options = {**extra_options, "verify": ca_cert} + else: + # Ensures secure connection by default, it is False in Airflow 1 + extra_options = {**extra_options, "verify": True} + + # Small hack to override the insecure header property passed from the + # extra in HTTPHook, which is a boolean but must be a string to be part + # of the headers + request_extra_headers = {"insecure": str(self.connection.insecure)} + + common_kwargs: Dict[str, Any] = dict( + _retry_args=dict( + wait=tenacity.wait_exponential(), + stop=tenacity.stop_after_attempt(self.num_retries), + retry=retry_handler, + ), + endpoint=endpoint, + extra_options=extra_options, + headers=request_extra_headers, + ) + + if self.connection.is_external(): + common_kwargs = {**common_kwargs, 
"auth": BearerAuth(cde_token)} + + if method.upper() == "GET": + response = http.run_with_advanced_retry(data=params, **common_kwargs) + else: + response = http.run_with_advanced_retry(json=params, **common_kwargs) + return response.json() + except Exception as err: + msg = "API call returned error(s)" + msg = f"{msg}:[{','.join(retry_handler.errors)}]" if retry_handler.errors else msg + self.log.error(msg) + raise CdeHookException(err) from err + + def get_cde_token(self) -> str: + """ + Obtains valid CDE token through CDP access token + + Returns: + cde_token: a valid token for submitting request to the CDE Cluster + """ + self.log.debug("Starting CDE token acquisition") + access_key, private_key = ( + self.connection.access_key, + self.connection.private_key, + ) + vcluster_endpoint = self.connection.get_vcluster_jobs_api_url() + try: + cdp_cred = CdpAccessKeyCredentials(access_key, private_key) + cde_vcluster = VirtualCluster(vcluster_endpoint) + cdp_auth = CdpAccessKeyV2TokenAuth( + cde_vcluster.get_service_id(), + cdp_cred, + cdp_endpoint=self.connection.cdp_endpoint, + altus_iam_endpoint=self.connection.altus_iam_endpoint, + ) + + cache_mech_extra_kw = {} + cache_dir = self.connection.cache_dir + if cache_dir: + cache_mech_extra_kw = {"cache_dir": cache_dir} + + cache_mech = EncryptedFileTokenCacheStrategy( + CdeTokenAuthResponse, + encryption_key=cdp_auth.get_auth_secret(), + **cache_mech_extra_kw, + ) + + cde_auth = CdeApiTokenAuth( + cde_vcluster, + cdp_auth, + cache_mech, + custom_ca_certificate_path=self.connection.ca_cert_path, + insecure=self.connection.insecure, + ) + cde_token = cde_auth.get_cde_authentication_token().access_token + self.log.debug("CDE token successfully acquired") + + except SecurityError as err: + self.log.error( + "Failed to get the cde auth token for the connection %s, error: %s", + self.cde_conn_id, + err, + ) + raise CdeHookException(err) from err + + return cde_token + + def submit_job( + self, + job_name: str, + variables: Optional[Dict[str, Any]] = None, + overrides: Optional[Dict[str, Any]] = None, + proxy_user: Optional[str] = None, + ) -> int: + """ + Submit a job run request + + :param job_name: The name of the job definition to run (should already be + defined in the virtual cluster). + :param variables: Runtime variables to pass to job run + :param overrides: Overrides of job parameters for this run + :return: the job run ID for a successful submission or an AirflowException + :rtype: int + """ + if proxy_user: + self.log.warning("Proxy user is not yet supported. 
Setting it to None.") + + body = dict( + variables=variables, + overrides=overrides, + # Shall be updated to proxy_user when we support this feature + user=None, + ) + response = self._do_api_call("POST", f"/jobs/{job_name}/run", body) + return response["id"] + + def kill_job_run(self, run_id: int) -> None: + """ + Kill a running job + + :param run_id: the run ID of the job run + """ + self._do_api_call("POST", f"/job-runs/{run_id}/kill") + + def check_job_run_status(self, run_id: int) -> str: + """ + Check and return the status of a job run + + :param run_id: the run ID of the job run + :return: the job run status + :rtype: str + """ + response = self._do_api_call("GET", f"/job-runs/{run_id}") + response_status: str = response["status"] + return response_status + + def get_conn(self): + raise NotImplementedError + + def get_pandas_df(self, sql): + raise NotImplementedError + + def get_records(self, sql): + raise NotImplementedError + + +class RetryHandler: + """ + Retry strategy for tenacity that retries if a 5xx response + or certain exceptions are encountered. + Client error (4xx) responses are considered fatal. + """ + + ALWAYS_RETRY_EXCEPTIONS = ( + requests.exceptions.ConnectionError, + requests.exceptions.Timeout, + ) + + def __init__(self) -> None: + self._errors: Set[Any] = set() + + @property + def errors(self) -> Set[Any]: + """The set of unique API call error messages if any.""" + return self._errors + + def __call__(self, attempt: Any) -> bool: + if attempt.failed: + return isinstance(attempt.exception(), self.ALWAYS_RETRY_EXCEPTIONS) + else: + if isinstance(attempt.result(), requests.Response): + response = attempt.result() + status = str(response.status_code) + ":" + response.reason + error_msg = (status + ":" + response.text.rstrip()) if response.text else status + self._errors.add(error_msg) + if response.status_code < 400: + return False + elif response.status_code >= 500 and response.status_code < 600: + return True + else: + raise AirflowException(error_msg) + return False diff --git a/airflow/providers/cloudera/hooks/cdw_hook.py b/airflow/providers/cloudera/hooks/cdw_hook.py new file mode 100644 index 0000000000000..7e98a716f10d2 --- /dev/null +++ b/airflow/providers/cloudera/hooks/cdw_hook.py @@ -0,0 +1,376 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+
+import csv
+import os
+import subprocess
+import time
+from io import StringIO
+from tempfile import NamedTemporaryFile
+
+from airflow.exceptions import AirflowException
+from airflow.providers.apache.hive.hooks.hive import HiveCliHook
+from airflow.utils.file import TemporaryDirectory
+from airflow.utils.operator_helpers import AIRFLOW_VAR_NAME_FORMAT_MAPPING
+
+HIVE_QUEUE_PRIORITIES = ["VERY_HIGH", "HIGH", "NORMAL", "LOW", "VERY_LOW"]
+JDBC_BACKEND_HIVE = "hive2"
+JDBC_BACKEND_IMPALA = "impala"
+
+
+def get_context_from_env_var():
+    """
+    Extract context from env variables, e.g. dag_id, task_id and execution_date,
+    so that they can be used inside BashOperator and PythonOperator.
+
+    :return: The context of interest.
+    """
+    return {
+        format_map["default"]: os.environ.get(format_map["env_var_format"], "")
+        for format_map in AIRFLOW_VAR_NAME_FORMAT_MAPPING.values()
+    }
+
+
+class CdwHook(HiveCliHook):
+    """Simple CDW hive CLI hook which extends the functionality of HiveCliHook
+    in order to conform to CDW's parameter needs.
+
+    :param cli_conn_id: airflow connection id to be used.
+    :param query_isolation: controls whether to use CDW's query isolation feature.
+        Only hive warehouses support this at the moment.
+    :param jdbc_driver: a safety valve for the JDBC driver class. It is not supposed to be
+        changed by default, as CdwHook guesses and uses the correct driver for impala.
+        The environment provides both JDBC 4.1 and JDBC 4.2 drivers; currently, JDBC 4.1 is
+        used for CDW. For hive, the driver class is not defined at all in the beeline CLI.
+    """
+
+    def __init__(
+        self,
+        cli_conn_id=None,
+        query_isolation=True,
+        jdbc_driver="com.cloudera.impala.jdbc41.Driver",
+    ):
+        super().__init__(cli_conn_id)
+        self.conn = self.get_connection(cli_conn_id)
+        self.query_isolation = query_isolation
+        self.jdbc_driver = jdbc_driver if jdbc_driver is not None else "com.cloudera.impala.jdbc41.Driver"
+        self.sub_process = None
+
+    def get_cli_cmd(self):
+        """This is supposed to be visible for testing."""
+        return self._prepare_cli_cmd()
+
+    def _prepare_cli_cmd(self, hide_secrets=False):
+        """
+        This function creates the command list from available information.
+
+        :param hide_secrets: whether to mask secrets with asterisks
+        """
+        conn = self.conn
+        cmd_extra = []
+
+        hive_bin = "beeline"  # only beeline is supported as client while connecting to CDW
+        jdbc_backend = CdwHook.get_jdbc_backend(conn)
+
+        jdbc_url = f"jdbc:{jdbc_backend}://{conn.host}{CdwHook.get_port_string(conn)}/{conn.schema}"
+
+        # HTTP+SSL is the default for CDW, but it can be overwritten in connection extra params if needed
+        if jdbc_backend == JDBC_BACKEND_IMPALA:
+            jdbc_url = self.add_parameter_to_jdbc_url(conn.extra_dejson, jdbc_url, "AuthMech", "3")
+        jdbc_url = self.add_parameter_to_jdbc_url(conn.extra_dejson, jdbc_url, "transportMode", "http")
+        jdbc_url = self.add_parameter_to_jdbc_url(conn.extra_dejson, jdbc_url, "httpPath", "cliservice")
+        jdbc_url = self.add_parameter_to_jdbc_url(
+            conn.extra_dejson, jdbc_url, "ssl", CdwHook.get_ssl_parameter(conn)
+        )
+
+        if jdbc_backend == JDBC_BACKEND_IMPALA:
+            cmd_extra += ["-d", self.jdbc_driver]
+
+        cmd_extra += ["-u", jdbc_url]
+        if conn.login:
+            cmd_extra += ["-n", conn.login]
+        if conn.password:
+            cmd_extra += ["-p", conn.password if not hide_secrets else "********"]
+
+        self.add_extra_parameters(jdbc_backend, cmd_extra)
+
+        return [hive_bin] + cmd_extra
+
+    def add_extra_parameters(self, jdbc_backend, cmd_extra):
+        """
+        Adds extra parameters to the beeline command in addition to the basic, needed ones.
+        This can be overridden in subclasses in order to change beeline behavior.
+        """
+        # this hive option is supposed to enforce query isolation regardless
+        # of the initial settings used while creating the virtual warehouse
+        if self.query_isolation and jdbc_backend == JDBC_BACKEND_HIVE:
+            cmd_extra += ["--hiveconf", "hive.query.isolation.scan.size.threshold=0B"]
+            cmd_extra += ["--hiveconf", "hive.query.results.cache.enabled=false"]
+            cmd_extra += [
+                "--hiveconf",
+                "hive.auto.convert.join.noconditionaltask.size=2505397589",
+            ]
+
+    @staticmethod
+    def get_jdbc_backend(conn):
+        """
+        Tries to guess the underlying database from the connection host. In CDW, JDBC URLs
+        look like the below:
+        hive:
+          - hs2-lbodor-airflow-hive.env-xkg48s.dwx.dev.cldr.work
+        impala:
+          - impala-proxy-lbodor-airflow-impala.env-xkg48s.dwx.dev.cldr.work:443
+          - coordinator-lbodor-impala-test.env-xkg48s.dwx.dev.cldr.work:443
+        So this method returns the database kind string which can be used in the JDBC string:
+        hive: 'hive2'
+        impala: 'impala'
+        """
+        return (
+            JDBC_BACKEND_IMPALA
+            if (conn.host.find("coordinator-") == 0 or conn.host.find("impala-proxy") == 0)
+            else JDBC_BACKEND_HIVE
+        )
+
+    @staticmethod
+    def get_port_string(conn):
+        """
+        hive: ''
+        impala: ':443'
+        """
+        backend = CdwHook.get_jdbc_backend(conn)
+        return ":443" if backend == JDBC_BACKEND_IMPALA else ""
+
+    @staticmethod
+    def get_ssl_parameter(conn):
+        """
+        hive: 'true'
+        impala: '1'
+        """
+        backend = CdwHook.get_jdbc_backend(conn)
+        return "1" if backend == JDBC_BACKEND_IMPALA else "true"
+
+    def run_cli(self, hql, schema="default", verbose=True, hive_conf=None):
+        """Copied from the hive hook, but with unnecessary parts removed, e.g. the mapred queue."""
+        conn = self.conn
+        schema = schema or conn.schema
+        if schema:
+            hql = f"USE {schema};\n{hql}"
+
+        with TemporaryDirectory(prefix="airflow_hiveop_") as tmp_dir:
+            with NamedTemporaryFile(dir=tmp_dir) as f:
+                hql = hql + "\n"
+                f.write(hql.encode("UTF-8"))
+                f.flush()
+                hive_cmd = self._prepare_cli_cmd()
+                env_context = get_context_from_env_var()
+                # Only extend the hive_conf if it is defined.
+                if hive_conf:
+                    env_context.update(hive_conf)
+                hive_conf_params = self._prepare_hiveconf(env_context)
+                hive_cmd.extend(hive_conf_params)
+                hive_cmd.extend(["-f", f.name])
+
+                if verbose:
+                    self.log.info("%s", " ".join(self._prepare_cli_cmd(hide_secrets=True)))
+                sub_process = subprocess.Popen(
+                    hive_cmd,
+                    stdout=subprocess.PIPE,
+                    stderr=subprocess.STDOUT,
+                    cwd=tmp_dir,
+                    close_fds=True,
+                )
+                self.sub_process = sub_process
+                stdout = ""
+                while True:
+                    line = sub_process.stdout.readline()
+                    if not line:
+                        break
+                    stdout += line.decode("UTF-8")
+                    if verbose:
+                        self.log.info(line.decode("UTF-8").strip())
+                sub_process.wait()
+
+                if sub_process.returncode:
+                    raise AirflowException(stdout)
+
+                return stdout
+
+    def kill(self):
+        """Terminates the underlying beeline subprocess, escalating to SIGKILL if needed."""
+        if hasattr(self, "sub_process") and self.sub_process is not None:
+            if self.sub_process.poll() is None:
+                print("Killing the Hive job")
+                self.sub_process.terminate()
+                time.sleep(60)
+                self.sub_process.kill()
+
+    @staticmethod
+    def add_parameter_to_jdbc_url(extra_dejson, jdbc_url, parameter_name, default_value=None):
+        """
+        Appends a parameter to the JDBC URL if it is found in the connection's JSON extras
+        or there is a non-None default value.
+        """
+        if extra_dejson is None or extra_dejson.get(parameter_name, default_value) is None:
+            return jdbc_url
+
+        return jdbc_url + f";{parameter_name}={extra_dejson.get(parameter_name, default_value)}"
+
+
+class CdwHiveMetastoreHook(CdwHook):
+    """A hive metastore hook which should behave the same as HiveMetastoreHook,
+    but instead of a kerberized, binary thrift connection it uses beeline as the client,
+    connecting to the sys database.
+    """
+
+    def __init__(self, cli_conn_id="metastore_default"):
+        """
+        In CdwHiveMetastoreHook this is supposed to be a beeline connection
+        pointing to the sys schema, so the conn should point to a hive CLI wrapper
+        connection in airflow, similarly to CdwHook's cli_conn_id.
+        """
+        super().__init__(cli_conn_id=cli_conn_id)
+        self.conn.schema = "sys"  # metastore database
+
+    def check_for_partition(self, schema, table, partition):
+        """
+        Checks whether a partition exists
+
+        :param schema: Name of hive schema (database) @table belongs to
+        :param table: Name of hive table @partition belongs to
+        :param partition: Expression that matches the partitions to check for
+        :rtype: bool
+        """
+        hql = (
+            "select dbs.name as db_name, tbls.tbl_name as tbl_name, partitions.part_name as "
+            "part_name from partitions left outer join tbls on tbls.tbl_id = partitions.tbl_id left "
+            f"outer join dbs on dbs.db_id = tbls.db_id where dbs.name = '{schema}' and tbls.tbl_name "
+            f"= '{table}' and partitions.part_name = '{partition}';"
+        )
+
+        response = self.run_cli(hql, self.conn.schema, verbose=True, hive_conf=None)
+        result_lines = CdwHiveMetastoreHook.parse_csv_lines(response)
+        results_without_header = CdwHiveMetastoreHook.get_results_without_header(
+            result_lines, "db_name,tbl_name,part_name"
+        )
+
+        self.log.info("partitions: %s", results_without_header)
+        return len(results_without_header) > 0
+
+    def check_for_named_partition(self, schema, table, partition):
+        """
+        Checks whether a partition with a given name exists
+
+        :param schema: Name of hive schema (database) @table belongs to
+        :param table: Name of hive table @partition belongs to
+        :param partition: Name of the partitions to check for (eg `a=b/c=d`)
+        :rtype: bool
+        """
+        raise Exception("TODO IMPLEMENT")
+
+    def get_table(self, table_name, db="default"):
+        """Get a metastore table object"""
+        if db == "default" and "." in table_name:
+            db, table_name = table_name.split(".")[:2]
+        hql = (
+            "select dbs.name as db_name, tbls.tbl_name as tbl_name from tbls left outer join dbs on "
+            f"dbs.db_id = tbls.db_id where dbs.name = '{db}' and tbls.tbl_name = '{table_name}' "
+        )
+
+        response = self.run_cli(hql, self.conn.schema, verbose=True, hive_conf=None)
+        result_lines = CdwHiveMetastoreHook.parse_csv_lines(response)
+
+        tables = CdwHiveMetastoreHook.get_results_without_header(result_lines, "db_name,tbl_name")
+        return tables
+
+    def get_tables(self, db, pattern="*"):
+        """Get the metastore tables in the given database that match the pattern."""
+        hql = (
+            "select dbs.name as db_name, tbls.tbl_name as tbl_name from tbls left outer join dbs on dbs.db_id"
+            f" = tbls.db_id where dbs.name = '{db}' and tbls.tbl_name like '{pattern.replace('*', '%')}' "
+        )
+        response = self.run_cli(hql, self.conn.schema, verbose=True, hive_conf=None)
+        result_lines = CdwHiveMetastoreHook.parse_csv_lines(response)
+
+        tables = CdwHiveMetastoreHook.get_results_without_header(result_lines, "db_name,tbl_name")
+
+        self.log.info("tables: %s", tables)
+        return tables
+
+    def get_databases(self, pattern="*"):
+        """Get the metastore databases that match the pattern."""
+        hql = f"select dbs.name from dbs where dbs.name LIKE '{pattern.replace('*', '%')}' "
+
+        response = self.run_cli(hql, self.conn.schema, verbose=True, hive_conf=None)
+        result_lines = CdwHiveMetastoreHook.parse_csv_lines(response)
+
+        # the query selects a single unaliased column, so the csv header is "name"
+        databases = CdwHiveMetastoreHook.get_results_without_header(result_lines, "name")
+
+        self.log.info("databases: %s", databases)
+        return databases
+
+    def get_partitions(self, schema, table_name, partition_filter=None):
+        """Returns a list of all partitions in a table."""
+        raise Exception("TODO IMPLEMENT")
+
+    def max_partition(self, schema, table_name, field=None, filter_map=None):
+        """
+        Returns the maximum value for all partitions with the given field in a table.
+        If only one partition key exists in the table, the key will be used as the field.
+        filter_map should be a partition_key:partition_value map and will be used to
+        filter out partitions.
+
+        :param schema: schema name.
+        :param table_name: table name.
+        :param field: partition key to get max partition from.
+        :param filter_map: partition_key:partition_value map used for partition filtering.
+        """
+        raise Exception("TODO IMPLEMENT")
+
+    def table_exists(self, table_name, db="default"):
+        """Check if a table exists."""
+        tables = self.get_table(table_name, db)
+        return len(tables) > 0
+
+    @staticmethod
+    def parse_csv_lines(response):
+        r"""
+        Parses a csv string by generating a list of lists.
+        E.g.
+        Input: 'cdw,no\ne,two'
+        Output: [['cdw', 'no'], ['e', 'two']]
+        """
+        readable_input = StringIO(response)
+        return list(csv.reader(readable_input, delimiter=","))
+
+    @staticmethod
+    def get_results_without_header(result_lines, header):
+        """Parses beeline output and removes noise before the given header (e.g. SLF4J warnings)."""
+        final_list = []
+        add_line = False
+        for line in result_lines:
+            if add_line:
+                final_list.append(line)
+            elif ",".join(line) == header:
+                add_line = True
+        return final_list
+
+    def add_extra_parameters(self, jdbc_backend, cmd_extra):
+        """
+        Overrides CdwHook.add_extra_parameters in order to enable behavior
+        which is optimal for fetching and parsing metadata in csv.
+        """
+        cmd_extra += ["--hiveconf", "hive.query.isolation.scan.size.threshold=1GB"]
+        cmd_extra += ["--silent=true"]
+        cmd_extra += ["--outputformat=csv2"]
+        cmd_extra += ["--showHeader=true"]
diff --git a/airflow/providers/cloudera/model/__init__.py b/airflow/providers/cloudera/model/__init__.py
new file mode 100644
index 0000000000000..13a83393a9124
--- /dev/null
+++ b/airflow/providers/cloudera/model/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/cloudera/model/cdp/__init__.py b/airflow/providers/cloudera/model/cdp/__init__.py
new file mode 100644
index 0000000000000..13a83393a9124
--- /dev/null
+++ b/airflow/providers/cloudera/model/cdp/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/cloudera/model/cdp/cde.py b/airflow/providers/cloudera/model/cdp/cde.py
new file mode 100644
index 0000000000000..71d9c60ed8426
--- /dev/null
+++ b/airflow/providers/cloudera/model/cdp/cde.py
@@ -0,0 +1,74 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Holds model objects related to CDE Environments.""" + +import re +from urllib.parse import urlparse + + +class VirtualCluster: + """Represents a CDE Virtual cluster and hold various helper methods.""" + + ACCESS_KEY_AUTH_ENDPOINT_PATH = "/gateway/cdptkn/knoxtoken/api/v1/token" + + def __init__(self, vcluster_endpoint: str) -> None: + """ + Args: + vcluster_endpoint: the endpoint of the Virtual Cluster corresponding to the + Job API URL + """ + self.vcluster_endpoint = vcluster_endpoint + auth_endpoint_url = urlparse(vcluster_endpoint) + self.host = auth_endpoint_url.hostname + + def get_service_id(self) -> str: + """ + Obtains cluster id from the virtual cluster endpoint. + + Returns: + cluster id string in the form 'cluster-' + + Raises: + ValueError: When the given url is not in the expected format + """ + pattern = re.compile("cde-[a-zA-Z0-9]*") + try: + first_match = pattern.findall(self.vcluster_endpoint)[0] + return re.sub("^cde-", "cluster-", first_match, 1) + except IndexError as err: + raise ValueError(f"Cluster ID not found in {self.vcluster_endpoint}") from err + + def get_auth_endpoint(self) -> str: + """ + Derive the authentication endpoint from the virtual cluster cluster endpoint + + Returns: + Endpoint of the authentication service + + Raises: + ValueError of the input has incorrect form + """ + auth_endpoint = re.sub("^https://[a-zA-Z0-9]*", "https://service", self.vcluster_endpoint, 1) + + if auth_endpoint == self.vcluster_endpoint: + raise ValueError( + f"Invalid vcluster endpoint given: {self.vcluster_endpoint}", + ) + + auth_endpoint_url = urlparse(auth_endpoint) + auth_endpoint_url = auth_endpoint_url._replace(path=self.ACCESS_KEY_AUTH_ENDPOINT_PATH) + return auth_endpoint_url.geturl() diff --git a/airflow/providers/cloudera/model/connection.py b/airflow/providers/cloudera/model/connection.py new file mode 100644 index 0000000000000..d2bbeb91021f2 --- /dev/null +++ b/airflow/providers/cloudera/model/connection.py @@ -0,0 +1,167 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Holds connections for the Cloudera Products""" + +import json +from json.decoder import JSONDecodeError +from typing import Optional +from urllib.parse import urlparse + +from airflow.models.connection import Connection + + +class CdeConnection(Connection): + """Connection details to the Cloudera Data Engineering product""" + + CDE_API_PREFIX = "/api/v1" + + def __init__( + self, + connection_id: str, + scheme: str, + host: str, + api_base_route: str, + access_key: str, + private_key: str, + port: Optional[int] = None, + cache_dir: Optional[str] = None, + ca_cert_path: Optional[str] = None, + proxy: Optional[str] = None, + cdp_endpoint: Optional[str] = None, + altus_iam_endpoint: Optional[str] = None, + insecure: bool = False, + ) -> None: + super().__init__( + conn_id=connection_id, + host=host, + login=access_key, + password=private_key, + port=port, + ) + self.conn_type = "cloudera_data_engineering" + self.scheme = scheme + self.api_base_route = api_base_route + self.cache_dir = cache_dir + self.ca_cert_path = ca_cert_path + self.proxy = proxy + self.cdp_endpoint = cdp_endpoint + self.altus_iam_endpoint = altus_iam_endpoint + self.insecure = insecure + + def is_external(self) -> bool: + """Checks if connection is external. External connections + are typically cross-services connections or connection defined + in an external Airflow instance. + + Returns: + True of connection is external, false otherwise + """ + return not self.is_internal() + + def is_internal(self) -> bool: + """Checks if connection is internal. Internal connections + are only meant to be used within a CDE service and are managed + automatically by the Virtual Cluster. + + Returns: + True of connection is internal, false otherwise + """ + return self.__internal_connection(self.host) + + def get_vcluster_jobs_api_url(self) -> str: + """Constructs the jobs api url from the elements defined in the connection. + + Returns: + vcluster_jobs_api_url: the jobs api url + """ + + vcluster_jobs_api_url = f"{self.scheme}://{self.host}" + if self.port: + vcluster_jobs_api_url += ":" + str(self.port) + vcluster_jobs_api_url += self.api_base_route + return vcluster_jobs_api_url + + @property + def access_key(self) -> str: + """CDP Access key + + Returns: + the access key associated to the connection + """ + return self.login + + @property + def private_key(self) -> str: + """CDP Private key + + Returns: + the private key associated to the connection + """ + # Relies on Airflow Connection password getter, + # so that the password is not stored in clear in the memory + return self.password + + @classmethod + def __internal_connection(cls, hostname: str) -> bool: + return hostname.endswith(".svc") or hostname.endswith(".svc.cluster.local") + + @classmethod + def from_airflow_connection(cls, conn: Connection) -> "CdeConnection": + """Factory method for constructing a CDE connection from an Airflow Connection. 
+ + Args: + conn: an Airflow Connection instance + + Returns: + A new CDE connection with the parameters derived from the Airflow connection + """ + try: + if conn.extra: + extra = json.loads(conn.extra) + else: + extra = {} + except JSONDecodeError as err: + raise ValueError(f"Invalid extra property: {repr(err)}") from err + if conn.host and "://" in conn.host: + conn_uri = conn.host + else: + conn_uri = conn.get_uri() + connection_url = urlparse(conn_uri) + + # Internal endpoints have base prefix + api_base_route = ( + cls.CDE_API_PREFIX if cls.__internal_connection(connection_url.hostname) else connection_url.path + ) + + return cls( + conn.conn_id, + connection_url.scheme, + connection_url.hostname, + api_base_route, + conn.login, + conn.password, + port=conn.port, + cache_dir=extra.get("cache_dir"), + ca_cert_path=extra.get("ca_cert_path"), + proxy=extra.get("proxy"), + cdp_endpoint=extra.get("cdp_endpoint"), + altus_iam_endpoint=extra.get("altus_iam_endpoint"), + insecure=extra.get("insecure", False), + ) + + def __repr__(self) -> str: + return repr(self.__dict__) diff --git a/airflow/providers/cloudera/operators/__init__.py b/airflow/providers/cloudera/operators/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/airflow/providers/cloudera/operators/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/providers/cloudera/operators/cde_operator.py b/airflow/providers/cloudera/operators/cde_operator.py new file mode 100644 index 0000000000000..65c2737926e56 --- /dev/null +++ b/airflow/providers/cloudera/operators/cde_operator.py @@ -0,0 +1,246 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import time + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.cloudera.hooks.cde_hook import CdeHook, CdeHookException + + +class CdeRunJobOperator(BaseOperator): + """ + Runs a job in a CDE Virtual Cluster. 
The ``CdeRunJobOperator`` runs the + named job with optional variables and overrides. The job and its resources + must have already been created via the specified virtual cluster jobs API. + + The virtual cluster API endpoint is specified by setting the + ``connection_id`` parameter. The "local" virtual cluster jobs API is the + default and has a special value of ``cde_runtime_api``. Authentication to + the API is handled automatically and any jobs in the DAG will run as the + user who submitted the DAG. + + Jobs can be defined in a virtual cluster with variable placeholders, + e.g. ``{{ inputdir }}``. Currently the fields supporting variable expansion + are Spark application name, Spark arguments, and Spark configurations. + Variables can be passed to the operator as a dictionary of key-value string + pairs. In addition to any user variables passed via the ``variables`` + parameter, the following standard Airflow macros are automatically + populated as variables by the operator (see + https://airflow.apache.org/docs/stable/macros-ref): + + * ``ds``: the execution date as ``YYYY-MM-DD`` + * ``ds_nodash``: the execution date as ``YYYYMMDD`` + * ``ts``: execution date in ISO 8601 format + * ``ts_nodash``: execution date in ISO 8601 format without '-', ':' or + timezone information + * ``run_id``: the run_id of the current DAG run + + If a CDE job needs to run with a different configuration, a task can be + configured with runtime overrides. For example to override the Spark + executor memory and cores for a task and to supply an additional config + parameter you could supply the following dictionary can be supplied to + the ``overrides`` parameter:: + + { + 'spark': { + 'executorMemory': '8g', + 'executorCores': '4', + 'conf': { + 'spark.kubernetes.memoryOverhead': '2048' + } + } + } + + See the CDE Jobs API documentation for the full list of parameters that + can be overridden. + + Via the ``wait`` parameter, jobs can either be submitted asynchronously to + the API (``wait=False``) or the task can wait until the job is complete + before exiting the task (default is ``wait=True``). If ``wait`` is + ``True``, the task exit status will reflect the final status of the + submitted job (or the task will fail on timeout if specified). If ``wait`` + is ``False`` the task status will reflect whether the job was successfully + submitted to the API or not. + + Note: all parameters below can also be provided through the + ``default_args`` field of the DAG. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CdeRunJobOperator` + + :param job_name: the name of the job in the target cluster, required + :param connection_id: the Airflow connection id for the target API + endpoint, default value ``'cde_runtime_api'`` + :param variables: a dictionary of key-value pairs to populate in the + job configuration, default empty dict. + :param overrides: a dictionary of key-value pairs to override in the + job configuration, default empty dict. + :param wait: if set to true, the operator will wait for the job to + complete in the target cluster. The task exit status will reflect the + status of the completed job. Default ``True`` + :param timeout: the maximum time to wait in seconds for the job to + complete if ``wait=True``. If set to ``None``, 0 or a negative number, + the task will never be timed out. Default ``0``. + :param job_poll_interval: the interval in seconds at which the target API + is polled for the job status. Default ``10``. 
+ :param api_retries: the number of times to retry an API request in the event + of a connection failure or non-fatal API error. Default ``9``. + """ + + template_fields = ("variables", "overrides") + ui_color = "#385f70" + ui_fgcolor = "#fff" + + DEFAULT_WAIT = True + DEFAULT_POLL_INTERVAL = 10 + DEFAULT_TIMEOUT = 0 + DEFAULT_RETRIES = 9 + DEFAULT_CONNECTION_ID = "cde_runtime_api" + + def __init__( + self, + job_name=None, + variables=None, + overrides=None, + connection_id=DEFAULT_CONNECTION_ID, + wait=DEFAULT_WAIT, + timeout=DEFAULT_TIMEOUT, + job_poll_interval=DEFAULT_POLL_INTERVAL, + api_retries=DEFAULT_RETRIES, + user=None, + **kwargs, + ): + super().__init__(**kwargs) + self.job_name = job_name + self.variables = variables or {} + self.overrides = overrides or {} + self.connection_id = connection_id + self.wait = wait + self.timeout = timeout + self.job_poll_interval = job_poll_interval + if user: + self.log.warning("Proxy user is not yet supported. Setting it to None.") + self.user = None + self.api_retries = api_retries + if not self.job_name: + raise ValueError("job_name required") + # Set internal state + self._hook = self.get_hook() + self._job_run_id = -1 + + def execute(self, context): + self._job_run_id = self.submit_job(context) + if self.wait: + self.wait_for_job() + + def on_kill(self): + if self._hook and self._job_run_id > 0: + self.log.info("Task killed, cancelling job run: %d", self._job_run_id) + try: + self._hook.kill_job_run(self._job_run_id) + except CdeHookException as err: + msg = f"Issue while killing CDE job. Exiting. Error details: {err}" + self.log.error(msg) + raise AirflowException(msg) from err + except Exception as err: + msg = ( + "Most probably unhandled error in CDE Airflow plugin." + f" Please report this issue to Cloudera. Details: {err}" + ) + self.log.error(msg) + raise AirflowException(msg) from err + + def get_hook(self): + """Return CdeHook using specified connection""" + return CdeHook(connection_id=self.connection_id, num_retries=self.api_retries) + + def submit_job(self, context): + """Submit a job run request to CDE via the hook""" + # merge user-supplied variables and airflow variables + user_vars = self.variables or {} + airflow_vars = { + "ds": context["ds"], + "ds_nodash": context["ds_nodash"], + "ts": context["ts"], + "ts_nodash": context["ts_nodash"], + "run_id": context["run_id"], + } + merged_vars = {**airflow_vars, **user_vars} + + try: + job_run_id = self._hook.submit_job(self.job_name, variables=merged_vars, overrides=self.overrides) + except CdeHookException as err: + msg = f"Issue while submitting job. Exiting. Error details: {err}" + self.log.error(msg) + raise AirflowException(msg) from err + except Exception as err: + msg = ( + "Most probably unhandled error in CDE Airflow plugin." + f" Please report this issue to Cloudera. 
Details: {err}" + ) + self.log.error(msg) + raise AirflowException(msg) from err + self.log.info("Job submitted with run id: %s", job_run_id) + + return job_run_id + + def wait_for_job(self): + """Wait for a submitted job run to complete and raise exception if failed""" + self.log.info("Waiting for job completion, job run id: %s", self._job_run_id) + end_time = None + if self.timeout > 0: + self.log.info("Wait timeout set to %d seconds", self.timeout) + end_time = int(time.time()) + self.timeout + + check_time = int(time.time()) + while not end_time or end_time > check_time: + try: + job_status = self._hook.check_job_run_status(self._job_run_id) + except CdeHookException as err: + msg = f"Issue while checking job status. Exiting. Error details: {err}" + self.log.error(msg) + raise AirflowException(msg) from err + except Exception as err: + msg = ( + "Most probably unhandled error in CDE Airflow plugin." + f" Please report this issue to Cloudera. Details: {err}" + ) + self.log.error(msg) + raise AirflowException(msg) from err + if job_status in ("starting", "running"): + msg = ( + f"Job run in {job_status} status," f" checking again in {self.job_poll_interval} seconds" + ) + self.log.info(msg) + elif job_status == "succeeded": + msg = f"Job run completed with {job_status} status" + self.log.info(msg) + return + elif job_status in ("failed", "killed", "unknown"): + msg = f"Job run exited with {job_status} status" + self.log.error(msg) + raise AirflowException(msg) + else: + msg = f"Got unexpected status when polling for job: {job_status}" + self.log.error(msg) + raise AirflowException(msg) + time.sleep(self.job_poll_interval) + check_time = int(time.time()) + + raise TimeoutError(f"Job run did not complete in {self.timeout} seconds") diff --git a/airflow/providers/cloudera/operators/cdw_operator.py b/airflow/providers/cloudera/operators/cdw_operator.py new file mode 100644 index 0000000000000..5542958d97fb2 --- /dev/null +++ b/airflow/providers/cloudera/operators/cdw_operator.py @@ -0,0 +1,103 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import re + +from airflow.models import BaseOperator +from airflow.providers.cloudera.hooks.cdw_hook import CdwHook +from airflow.utils.operator_helpers import context_to_airflow_vars + + +class CdwExecuteQueryOperator(BaseOperator): + """ + Executes hql code in CDW. This class inherits behavior + from HiveOperator, and instantiates a CdwHook to do the work. + + .. 
seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:CdwExecuteQueryOperator` + """ + + template_fields = ("hql", "schema", "hiveconfs") + template_ext = ( + ".hql", + ".sql", + ) + ui_color = "#522a9f" + ui_fgcolor = "#fff" + + def __init__( + self, + hql, + schema="default", + hiveconfs=None, + hiveconf_jinja_translate=False, + cli_conn_id="hive_cli_default", + jdbc_driver=None, + # new CDW args + use_proxy_user=False, # pylint: disable=unused-argument + query_isolation=True, # TODO: implement + *args, + **kwargs, + ): + + super().__init__(*args, **kwargs) + self.hql = hql + self.schema = schema + self.hiveconfs = hiveconfs or {} + self.hiveconf_jinja_translate = hiveconf_jinja_translate + self.run_as = None + self.cli_conn_id = cli_conn_id + self.jdbc_driver = jdbc_driver + self.query_isolation = query_isolation + # assigned lazily - just for consistency we can create the attribute with a + # `None` initial value, later it will be populated by the execute method. + # This also makes `on_kill` implementation consistent since it assumes `self.hook` + # is defined. + self.hook = None + + def get_hook(self): + """Simply returns a CdwHook with the provided hive cli connection.""" + return CdwHook( + cli_conn_id=self.cli_conn_id, + query_isolation=self.query_isolation, + jdbc_driver=self.jdbc_driver, + ) + + def prepare_template(self): + if self.hiveconf_jinja_translate: + self.hql = re.sub(r"(\$\{(hiveconf:)?([ a-zA-Z0-9_]*)\})", r"{{ \g<3> }}", self.hql) + + def execute(self, context): + self.log.info("Executing: %s", self.hql) + self.hook = self.get_hook() + + if self.hiveconf_jinja_translate: + self.hiveconfs = context_to_airflow_vars(context) + else: + self.hiveconfs.update(context_to_airflow_vars(context)) + + self.log.info("Passing HiveConf: %s", self.hiveconfs) + self.hook.run_cli(hql=self.hql, schema=self.schema, hive_conf=self.hiveconfs) + + def dry_run(self): + self.hook = self.get_hook() + self.hook.test_hql(hql=self.hql) + + def on_kill(self): + if self.hook: + self.hook.kill() diff --git a/airflow/providers/cloudera/provider.yaml b/airflow/providers/cloudera/provider.yaml new file mode 100644 index 0000000000000..b148ca805c1c8 --- /dev/null +++ b/airflow/providers/cloudera/provider.yaml @@ -0,0 +1,71 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +--- +package-name: apache-airflow-providers-cloudera +name: Cloudera +description: | + `Cloudera <https://cloudera.com/>`__ + +versions: + - 1.0.0 + +additional-dependencies: + - apache-airflow>=2.0.0 + +integrations: + - integration-name: Cloudera CDE + external-doc-url: https://cloudera.com/ + how-to-guide: + - /docs/apache-airflow-providers-cloudera/operators/cde_run_job.rst + tags: [service] + - integration-name: Cloudera CDW + external-doc-url: https://cloudera.com/ + how-to-guide: + - /docs/apache-airflow-providers-cloudera/operators/execute_query.rst + tags: [service] + +operators: + - integration-name: Cloudera CDE + python-modules: + - airflow.providers.cloudera.operators.cde_operator + - integration-name: Cloudera CDW + python-modules: + - airflow.providers.cloudera.operators.cdw_operator + +hooks: + - integration-name: Cloudera CDE + python-modules: + - airflow.providers.cloudera.hooks.cde_hook + - integration-name: Cloudera CDW + python-modules: + - airflow.providers.cloudera.hooks.cdw_hook + +sensors: + - integration-name: Cloudera CDW + python-modules: + - airflow.providers.cloudera.sensors.cdw_sensor + +hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ + - airflow.providers.cloudera.hooks.cde_hook.CdeHook + - airflow.providers.cloudera.hooks.cdw_hook.CdwHook + +connection-types: + - hook-class-name: airflow.providers.cloudera.hooks.cde_hook.CdeHook + connection-type: CdeConnection + - hook-class-name: airflow.providers.cloudera.hooks.cdw_hook.CdwHook + connection-type: CdwConnection diff --git a/airflow/providers/cloudera/security/__init__.py b/airflow/providers/cloudera/security/__init__.py new file mode 100644 index 0000000000000..8716e90809c5d --- /dev/null +++ b/airflow/providers/cloudera/security/__init__.py @@ -0,0 +1,104 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""Security module for handling authentication to Cloudera Services""" +from abc import ABC, abstractmethod +from http import HTTPStatus +from typing import Any + +import requests +from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential + +from airflow.providers.cloudera.hooks import CdpHookException +from airflow.utils.log.logging_mixin import LoggingMixin, logging # type: ignore + +LOG = logging.getLogger(__name__) + + +class SecurityError(CdpHookException): + """Root security exception, to be used to catch any security issue""" + + +class TokenResponse(ABC, LoggingMixin): + """Base class for token responses""" + + @abstractmethod + def is_valid(self) -> bool: + """Check if token is still valid + + Returns: True if token is valid, false otherwise + """ + raise NotImplementedError + + +class ClientError(requests.exceptions.HTTPError): + """When request fails because of a Client side error""" + + +class ServerError(requests.exceptions.HTTPError): + """When request fails because of an Internal/Server side error""" + + +ALWAYS_RETRY_EXCEPTIONS = ( + requests.exceptions.ConnectionError, + requests.exceptions.Timeout, + ServerError, +) + + +@retry( + wait=wait_exponential(), + stop=stop_after_attempt(3), + retry=retry_if_exception_type(ALWAYS_RETRY_EXCEPTIONS), + reraise=True, +) +def submit_request(method, uri, *args: Any, **kw_args) -> requests.Response: + """ + Helper method for submitting HTTP request and handling common errors + Args: + method: http method, GET, POST, etc. + uri: endpoint of the requests + args: arguments given to the function + kw_args: keyword arguments given to the function + + Returns: + Response of the http request + + Raises: + ClientError if response status code is 4xx + ServerError if response status code is 5xx + corresponding issued requests.exceptions.RequestException if requests throws an error and + cannot complete successfully + """ + try: + LOG.debug("Issuing request: %s %s", method, uri) + response = requests.request(method, uri, *args, **kw_args) + except requests.exceptions.RequestException as err: + print(err) + LOG.debug("Failed to query endpoint %s %s, error: %s", method, uri, repr(err)) + raise + + if response.status_code >= HTTPStatus.BAD_REQUEST: + status = str(response.status_code) + ":" + response.reason + error_msg = (status + ":" + response.text.rstrip()) if response.text else status + LOG.debug("Failed to query endpoint %s %s, error: %s", method, uri, error_msg) + if response.status_code < HTTPStatus.INTERNAL_SERVER_ERROR: + raise ClientError(error_msg) + if response.status_code >= HTTPStatus.INTERNAL_SERVER_ERROR: + raise ServerError(error_msg) + + return response diff --git a/airflow/providers/cloudera/security/cde_security.py b/airflow/providers/cloudera/security/cde_security.py new file mode 100644 index 0000000000000..882e63bbffdfd --- /dev/null +++ b/airflow/providers/cloudera/security/cde_security.py @@ -0,0 +1,194 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Handles CDE authentication""" +from abc import ABC, abstractmethod +from datetime import datetime, timedelta +from typing import Dict, Optional, Union +from urllib.parse import urlparse + +import requests + +from airflow.providers.cloudera.model.cdp.cde import VirtualCluster +from airflow.providers.cloudera.security import TokenResponse, submit_request +from airflow.providers.cloudera.security.cdp_security import CdpAuth, CdpSecurityError +from airflow.providers.cloudera.security.token_cache import ( + Cache, + CacheableTokenAuth, + GetAuthTokenError, + TokenCacheStrategy, +) +from airflow.utils.log.logging_mixin import LoggingMixin, logging  # type: ignore + +LOG = logging.getLogger(__name__) + + +class BearerAuth(requests.auth.AuthBase): + """Helper class for defining the Bearer token based authentication mechanism.""" + + def __init__(self, token: str) -> None: + self.token = token + + def __call__(self, r: requests.PreparedRequest) -> requests.PreparedRequest:  # pragma: no cover + # (it is only executed inside the requests call) + r.headers["authorization"] = f"Bearer {self.token}" + return r + + +class CdeTokenAuthResponse(TokenResponse, LoggingMixin): + """CDE Token Response object""" + + def __init__(self, access_token: str, expires_in: int): + self.access_token = access_token + self.expires_in = expires_in + + @classmethod + def from_response(cls, response: requests.Response) -> "CdeTokenAuthResponse": + """Factory method for creating a new instance of CdeTokenAuthResponse + from a valid JSON-formatted request response + + Args: + response: Response obtained from the Knox endpoint + + Returns: + New, corresponding instance of CdeTokenAuthResponse + """ + response_json = response.json() + return cls(response_json.get("access_token"), response_json.get("expires_in")) + + def is_valid(self) -> bool: + current_time = datetime.utcnow() + token_time = datetime.utcfromtimestamp(self.expires_in / 1000) + # Treat the token as invalid slightly earlier to avoid edge cases where the current + # time is too close to the token expiry time (the token could otherwise expire by + # the time CdeHook uses it) + max_valid_time = token_time - timedelta(minutes=5) + + LOG.debug( + "Current time is: %s. Token expires at: %s. Token must be renewed from: %s", + current_time, + token_time, + max_valid_time, + ) + + return current_time <= max_valid_time + + def __eq__(self, other: object) -> bool: + if not isinstance(other, CdeTokenAuthResponse): + return False + # Following line for autocompletion + other_response: CdeTokenAuthResponse = other + return ( + self.access_token == other_response.access_token and self.expires_in == other_response.expires_in + ) + + def __repr__(self) -> str: + return ( + f"{CdeTokenAuthResponse.__name__}" + f"{{Token: {self.access_token}, Expires In: {self.expires_in}}}" + ) + + +class CdeAuth(ABC): + """Interface for CDE Authentication""" + + @abstractmethod + def get_cde_authentication_token(self) -> CdeTokenAuthResponse: + """Obtains a CDE access token. 
+ + Returns: + CDE JWT token Response + + Raises: + GetAuthTokenError if it is not possible to retrieve the CDE token + """ + raise NotImplementedError + + +class CdeApiTokenAuth(CdeAuth, CacheableTokenAuth): + """Authentication class for obtaining CDE token from CDP API token""" + + def __init__( + self, + cde_vcluster: VirtualCluster, + cdp_auth: CdpAuth, + token_cache_strategy: Optional[TokenCacheStrategy] = None, + custom_ca_certificate_path: Optional[str] = None, + insecure: Optional[bool] = False, + ) -> None: + self.cde_vcluster = cde_vcluster + self.cdp_auth = cdp_auth + if token_cache_strategy: + super().__init__(token_cache_strategy) + self.custom_ca_certificate_path = custom_ca_certificate_path + self.insecure = insecure + + @Cache(token_response_type=CdeTokenAuthResponse) + def get_cde_authentication_token(self) -> CdeTokenAuthResponse: + return self.fetch_authentication_token() + + def fetch_authentication_token(self) -> CdeTokenAuthResponse: + """Obtains a fresh token directly from the target system + + Returns: + valid fresh token + + Raises: + GetAuthTokenError if there was an issue while retrieving the token + """ + try: + workload_name = "DE" + cdp_token = self.cdp_auth.generate_workload_auth_token(workload_name) + except CdpSecurityError as err: + LOG.error("Could not obtain CDP token %s", err) + raise GetAuthTokenError(err) from err + + # Exchange the CDP access token for a CDE/CDW access token + try: + auth_endpoint = self.cde_vcluster.get_auth_endpoint() + except ValueError as err: + LOG.error("Could not determine authentication endpoint: %s", err) + raise GetAuthTokenError(err) from err + + try: + kw_args: Dict[str, Union[str, bool, BearerAuth]] = { + "auth": BearerAuth(cdp_token.token), + } + + if self.insecure: + kw_args = {**kw_args, "verify": False} + elif self.custom_ca_certificate_path is not None: + kw_args = {**kw_args, "verify": self.custom_ca_certificate_path} + + response = submit_request("GET", auth_endpoint, **kw_args) + except Exception as err: + LOG.error("Could not execute auth request: %s", repr(err)) + raise GetAuthTokenError(err) from err + + cde_token = CdeTokenAuthResponse.from_response(response) + + LOG.info( + "Acquired CDE token expiring at %s", + datetime.fromtimestamp(cde_token.expires_in / 1000), + ) + + return cde_token + + def get_cache_key(self) -> str: + vcluster_url = urlparse(self.cde_vcluster.vcluster_endpoint) + vcluster_host = vcluster_url.netloc + return f"{self.cdp_auth.get_auth_identifier()}____{vcluster_host}" diff --git a/airflow/providers/cloudera/security/cdp_requests/__init__.py b/airflow/providers/cloudera/security/cdp_requests/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/airflow/providers/cloudera/security/cdp_requests/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/providers/cloudera/security/cdp_requests/cdpcurl.py b/airflow/providers/cloudera/security/cdp_requests/cdpcurl.py new file mode 100644 index 0000000000000..ea16298e5d4c1 --- /dev/null +++ b/airflow/providers/cloudera/security/cdp_requests/cdpcurl.py @@ -0,0 +1,84 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""cdpcurl implementation.""" +import datetime +import pprint +import sys +from email.utils import formatdate + +from airflow.providers.cloudera.security import submit_request +from airflow.providers.cloudera.security.cdp_requests.cdpv1sign import make_signature_header + +__author__ = "cloudera" + +IS_VERBOSE = False + + +def __log(*args, **kwargs): + if not IS_VERBOSE: + return + stderr_pp = pprint.PrettyPrinter(stream=sys.stderr) + stderr_pp.pprint(*args, **kwargs) + + +def __now(): + return datetime.datetime.now(datetime.timezone.utc) + + +# pylint: disable=too-many-arguments,too-many-locals +def make_request(method, uri, headers, data, access_key, private_key, data_binary, verify=True): + """ + Make an HTTP request with CDP request signing. + + :param method: str + :param uri: str + :param headers: dict + :param data: str + :param access_key: str + :param private_key: str + :param data_binary: bool + :param verify: bool + :return: http response object + """ + + if "x-altus-auth" in headers: + raise Exception("x-altus-auth found in headers!") + if "x-altus-date" in headers: + raise Exception("x-altus-date found in headers!") + headers["x-altus-date"] = formatdate(timeval=__now().timestamp(), usegmt=True) + headers["x-altus-auth"] = make_signature_header(method, uri, headers, access_key, private_key) + + if data_binary: + return __send_request(uri, data, headers, method, verify) + return __send_request(uri, data.encode("utf-8"), headers, method, verify) + + +def __send_request(uri, data, headers, method, verify): + __log("\nHEADERS++++++++++++++++++++++++++++++++++++") + __log(headers) + + __log("\nBEGIN REQUEST++++++++++++++++++++++++++++++++++++") + __log("Request URL = " + uri) + + response = submit_request(method, uri, headers=headers, data=data, verify=verify) + + __log("\nRESPONSE++++++++++++++++++++++++++++++++++++") + __log("Response code: %d\n" % response.status_code) + + return response diff --git a/airflow/providers/cloudera/security/cdp_requests/cdpv1sign.py b/airflow/providers/cloudera/security/cdp_requests/cdpv1sign.py new file mode 100644 index 0000000000000..97511c4917dc1 --- /dev/null +++ b/airflow/providers/cloudera/security/cdp_requests/cdpv1sign.py @@ -0,0 +1,102 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Implementation of the CDP API signature specification, V1.""" +import json +from base64 import b64decode, urlsafe_b64encode +from collections import OrderedDict +from urllib.parse import urlparse + +from pure25519 import eddsa + + +def create_canonical_request_string(method, uri, headers, auth_method): + """Create a canonical request string from aspects of the request.""" + headers_of_interest = [] + for header_name in ["content-type", "x-altus-date"]: + found = False + for key in headers: + key_lc = key.lower() + if headers[key] is not None and key_lc == header_name: + headers_of_interest.append(headers[key].strip()) + found = True + if not found: + headers_of_interest.append("") + + # Our signature verification will treat a query with no = as part of the + # path, so we do as well. It appears to be a behavior left to the server + # implementation, and Python and our Java servlet implementation disagree. + uri_components = urlparse(uri) + path = uri_components.path + if not path: + path = "/" + if uri_components.query and "=" not in uri_components.query: + path += "?" + uri_components.query + + canonical_string = method.upper() + "\n" + canonical_string += "\n".join(headers_of_interest) + "\n" + canonical_string += path + "\n" + canonical_string += auth_method + + return canonical_string + + +def create_signature_string(canonical_string, private_key): + """ + Create the string form of the digital signature of the canonical request + string. + """ + seed = b64decode(private_key) + if len(seed) != 32: + raise Exception("Not an Ed25519 private key!") + public_key = eddsa.publickey(seed) + signature = eddsa.signature(canonical_string.encode("utf-8"), seed, public_key) + return urlsafe_b64encode(signature).strip().decode("utf-8") + + +def create_encoded_authn_params_string(access_key, auth_method): + """Create the base 64 encoded string of authentication parameters.""" + auth_params = OrderedDict() + auth_params["access_key_id"] = access_key + auth_params["auth_method"] = auth_method + encoded_json = json.dumps(auth_params).encode("utf-8") + return urlsafe_b64encode(encoded_json).strip() + + +def create_signature_header(encoded_authn_params, signature): + """ + Combine the encoded authentication parameters string and signature string + into the signature header value. + """ + return f"{encoded_authn_params.decode('utf-8')}.{signature}" + + +def make_signature_header(method, uri, headers, access_key, private_key): + """ + Generates the value to be used for the x-altus-auth header in the service + call. 
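+ + For illustration only (all values hypothetical), the canonical string that is + signed here has the form:: + + POST + application/json + Tue, 03 May 2022 10:00:00 GMT + /iam/generateWorkloadAuthToken + ed25519v1 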
+ """ + if len(private_key) != 44: + raise Exception("Only ed25519v1 keys are supported!") + + auth_method = "ed25519v1" + + canonical_string = create_canonical_request_string(method, uri, headers, auth_method) + signature = create_signature_string(canonical_string, private_key) + encoded_authn_params = create_encoded_authn_params_string(access_key, auth_method) + signature_header = create_signature_header(encoded_authn_params, signature) + return signature_header diff --git a/airflow/providers/cloudera/security/cdp_security.py b/airflow/providers/cloudera/security/cdp_security.py new file mode 100644 index 0000000000000..fb2378245ac49 --- /dev/null +++ b/airflow/providers/cloudera/security/cdp_security.py @@ -0,0 +1,203 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Handles CDP authentication""" + +import json +from abc import ABC, abstractmethod +from typing import NamedTuple, Optional + +import requests + +from airflow.providers.cloudera.security import SecurityError, TokenResponse +from airflow.providers.cloudera.security.cdp_requests.cdpcurl import make_request +from airflow.utils.log.logging_mixin import LoggingMixin, logging # type: ignore + +LOG = logging.getLogger(__name__) + + +class CdpSecurityError(SecurityError): + """Root exception for CDP authentication issues""" + + +class GetCrnError(CdpSecurityError): + """Exception used when there is an issue while retrieving the environment CRN""" + + +class CdpApiAError(CdpSecurityError): + """Exception used when there is an issue while interacting with CDP API""" + + +class GetWorkloadAuthTokenError(CdpSecurityError): + """Exception used when there is an issue while retrieving the workload token""" + + +class CdpTokenAuthResponse(TokenResponse): + """CDP Token Response object""" + + def __init__(self, response: requests.Response): + response_dict = json.loads(response.content) + self.token = response_dict.get("token") + self.expires_at = response_dict.get("expiresAt") + + def is_valid(self) -> bool: + raise NotImplementedError + + def __repr__(self) -> str: + return f"{CdpTokenAuthResponse.__name__}" f"{{Token: {self.token}, Expires At: {self.expires_at}}}" + + +class CdpAuth(ABC, LoggingMixin): + """Interface for CDP Authentication""" + + @abstractmethod + def get_auth_identifier(self) -> str: + """Gets the identifier of the connection + + Returns: + identifier of the connection + """ + raise NotImplementedError + + @abstractmethod + def get_auth_secret(self) -> str: + """Gets the secret of the connection + + Returns: + secret of the connection + """ + raise NotImplementedError + + @abstractmethod + def generate_workload_auth_token(self, workload_name: str) -> CdpTokenAuthResponse: + """Obtains a CDP access token. 
+ + Args: + workload_name: kind of workload for which we request the CDP token + + Returns: + CDP JWT token Response + + Raises: + CdpApiAError if it is not possible to retrieve the CDP token + """ + raise NotImplementedError + + +class CdpAccessKeyCredentials(NamedTuple): + """Represent access/private key pair for CDP Access Key V2 authentication""" + + access_key: str + private_key: str + + +class CdpAccessKeyV2TokenAuth(CdpAuth): + """Authentication class for obtaining CDP token from access key/private key credentials""" + + CDP_ENDPOINT_DEFAULT = "https://api.us-west-1.cdp.cloudera.com" + CDP_DESCRIBE_SERVICE_ROUTE = "/api/v1/de/describeService" + ALTUS_IAM_ENDPOINT_DEFAULT = "https://iamapi.us-west-1.altus.cloudera.com" + ALTUS_IAM_GEN_WORKLOAD_AUTH_TOKEN_ROUTE = "/iam/generateWorkloadAuthToken" + + def __init__( + self, + service_id: str, + cdp_cred: CdpAccessKeyCredentials, + cdp_endpoint: Optional[str] = None, + altus_iam_endpoint: Optional[str] = None, + ) -> None: + self.service_id = service_id + self.cdp_cred = cdp_cred + self.cdp_describe_service_endpoint = cdp_endpoint if cdp_endpoint else self.CDP_ENDPOINT_DEFAULT + self.cdp_describe_service_endpoint += self.CDP_DESCRIBE_SERVICE_ROUTE + self.altus_iam_gen_workload_auth_endpoint = ( + altus_iam_endpoint if altus_iam_endpoint else self.ALTUS_IAM_ENDPOINT_DEFAULT + ) + self.altus_iam_gen_workload_auth_endpoint += self.ALTUS_IAM_GEN_WORKLOAD_AUTH_TOKEN_ROUTE + + def generate_workload_auth_token(self, workload_name: str) -> CdpTokenAuthResponse: + LOG.debug("Authenticating with access key: %s", self.cdp_cred.access_key) + LOG.debug("Using Cluster ID: %s", self.service_id) + + # Get the environment-crn + env_crn = self.get_env_crn() + LOG.debug("Using environment-crn %s", env_crn) + + # Exchange the access key for a CDP access token + cdp_token = self._generate_workload_auth_token(env_crn, workload_name) + LOG.debug("Exchanged access key for CDP access token") + + return cdp_token + + def get_env_crn(self) -> str: + """ + Gets the associated environment CRN of the given cluster + + Returns: + environment Cloudera Resource Name + + Raises: + GetCrnError if it is not possible to retrieve the environment CRN + """ + headers = {"Content-Type": "application/json"} + # make_request only accepts a string for the request body + request_body = f'{{"clusterId": "{self.service_id}"}}' + try: + response = make_request( + "POST", + self.cdp_describe_service_endpoint, + headers, + request_body, + self.cdp_cred.access_key, + self.cdp_cred.private_key, + False, + True, + ) + except Exception as err: + LOG.error("Issue while performing request to fetch environment-crn: %s", repr(err)) + raise GetCrnError(err) from err + + environment_crn = response.json().get("service").get("environmentCrn") + return environment_crn + + def get_auth_identifier(self) -> str: + return self.cdp_cred.access_key + + def get_auth_secret(self) -> str: + return self.cdp_cred.private_key + + def _generate_workload_auth_token(self, env_crn: str, workload_name: str) -> CdpTokenAuthResponse: + headers = {"Content-Type": "application/json"} + # make_request only accepts a string for the request body + request_body = f'{{"workloadName": "{workload_name}", "environmentCRN": "{env_crn}"}}' + try: + response = make_request( + "POST", + self.altus_iam_gen_workload_auth_endpoint, + headers, + request_body, + self.cdp_cred.access_key, + self.cdp_cred.private_key, + False, + True, + ) + except Exception as err: + LOG.error("Could not exchange cdp token with access key %s", repr(err)) + 
raise CdpApiAError(err) from err + + cdp_token = CdpTokenAuthResponse(response) + return cdp_token diff --git a/airflow/providers/cloudera/security/token_cache.py b/airflow/providers/cloudera/security/token_cache.py new file mode 100644 index 0000000000000..22a4f487089c1 --- /dev/null +++ b/airflow/providers/cloudera/security/token_cache.py @@ -0,0 +1,298 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Handles token caching and its various caching mechanisms""" +import base64 +import os +from abc import ABC, abstractmethod +from functools import wraps +from json import JSONDecodeError, dumps, loads +from pathlib import Path +from typing import Callable, Optional, Type + +from cryptography.fernet import Fernet, InvalidToken + +from airflow.providers.cloudera.security import SecurityError, TokenResponse +from airflow.utils.log.logging_mixin import LoggingMixin, logging  # type: ignore + +LOG = logging.getLogger(__name__) + + +class CacheError(SecurityError): + """Exception used when there is an issue while interacting with the token cache""" + + +class FetchAuthTokenError(SecurityError): + """Exception used when there is an issue while fetching the token from the Cloudera APIs""" + + +class GetAuthTokenError(SecurityError): + """Exception used when there is an issue while getting the token""" + + +class TokenCacheStrategy(ABC, LoggingMixin): + """Base class from which token caching strategies must be derived. + A Fernet-based encryptor is available for encrypting cache content if necessary + """ + + def __init__(self, token_response_class: Type[TokenResponse], encryption_key: Optional[str]) -> None: + self.token_response_class = token_response_class + if encryption_key: + fernet_encryption_key = self.get_fernet_encryption_key(encryption_key) + self.encryptor = Fernet(fernet_encryption_key) + + @classmethod + def get_fernet_encryption_key(cls, encryption_key: str) -> bytes: + """ + Derive a valid Fernet encryption key from the candidate encryption key + if it does not already fit the Fernet module's requirements (32 characters, base64-encoded) + """ + if not encryption_key: + raise ValueError("Encryption key cannot be None or empty") + if len(encryption_key) < 32: + raise ValueError("Encryption key is too short. It must be at least 32 characters.") + if len(encryption_key) > 32: + final_encryption_key = encryption_key[:32] + LOG.debug("Encryption key is too long. Truncating to the right size") + else: + final_encryption_key = encryption_key + LOG.debug("Encryption key has the right size.") + + # Truncate in case of accented letters. + # TODO: If this needs to be more generic, we should handle input encryption keys + # in encodings other than utf-8. Currently not an issue. 
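+ # For example (hypothetical key): a 40-character input key is truncated to its + # first 32 UTF-8 bytes and then urlsafe-base64-encoded into the 44-character + # key that Fernet expects. 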
+ final_encryption_key_bytes = final_encryption_key.encode("utf-8")[:32] + base64key = base64.urlsafe_b64encode(final_encryption_key_bytes) + + return base64key + + @abstractmethod + def get_cached_auth_token(self, cache_key: str) -> TokenResponse: + """Gets token from the cache + + Args: + cache_key: cache key to retrieve + + Returns: + token associated to the cache key + + Raises: + CacheError if it cannot be obtained from the cache + """ + raise NotImplementedError + + @abstractmethod + def cache_auth_token(self, cache_key: str, token: TokenResponse) -> None: + """Caches the token and associates it to the given cache key + + Args: + cache_key: cache key used to store the token + token: the token to cache + + Raises: + CacheError if the token cannot be cached + """ + raise NotImplementedError + + @abstractmethod + def clear_cached_auth_token(self, cache_key: str) -> None: + """Deletes token associated to the cache key from the cache + + Args: + cache_key: cache key to clear + + Raises: + CacheError if the cache entry cannot be deleted + """ + raise NotImplementedError + + +class EncryptedFileTokenCacheStrategy(TokenCacheStrategy): + """ + File based caching mechanism. A file is created for each cache entry. + Content of the cache is encrypted + """ + + CACHE_SUB_DIR = "token_cache" + + def __init__( + self, + token_response_class: Type[TokenResponse], + encryption_key: str, + cache_dir: Optional[str] = ".", + ) -> None: + super().__init__(token_response_class, encryption_key=encryption_key) + try: + if cache_dir and cache_dir.strip(): + self.cache_dir = Path(cache_dir.strip()) + if not os.path.isdir(self.cache_dir): + raise ValueError(f"Cache dir {self.cache_dir} is not a directory.") + self.cache_dir = self.cache_dir / Path(self.CACHE_SUB_DIR) + LOG.debug("Creating directory %s", self.cache_dir) + try: + self.cache_dir.mkdir(mode=0o770, exist_ok=True) + LOG.debug("Directory created successfully") + except Exception as err: + LOG.error("Failed to create %s", self.cache_dir) + raise CacheError( + err, + msg=f"Cache Directory {self.cache_dir} could not be created", + ) from err + else: + LOG.error("No value defined for cache_dir") + raise ValueError("Cache dir is empty") + except ValueError as err: + LOG.error("Failed to initialize the caching mechanism") + raise CacheError(err, msg="Cache Directory and Cache Keys must be specified") from err + self.cache_encoding = "utf-8" + + def get_cache_path(self, cache_key: str) -> Path: + """Cache path for associated cache key + + Args: + cache_key: the cache key + + Returns: + Absolute cache path + """ + if cache_key: + return (self.cache_dir / Path(cache_key)).absolute() + raise ValueError("Cache key must not be empty") + + def get_cached_auth_token(self, cache_key: str) -> TokenResponse: + # Read raw content + try: + with open(self.get_cache_path(cache_key)) as cache_file: + content = cache_file.read().splitlines() + except Exception as err: + raise CacheError( + err, f"Cache file {cache_key} does not exist or issues while reading it" + ) from err + + # Decrypt content + try: + content_dict = loads(self.encryptor.decrypt(content[0].encode(self.cache_encoding))) + token = self.token_response_class(**content_dict) # type: ignore + return token + except InvalidToken as err: + raise CacheError( + err, + "Issue while decrypting cache content. 
Please check if the file is corrupted.", + ) from err + except IndexError as err: + raise CacheError(err, f"Issues while reading cache {cache_key}") from err + except (TypeError, JSONDecodeError) as err: + raise CacheError(err, "Malformed cache token. Please check if the file is corrupted.") from err + + def cache_auth_token(self, cache_key: str, token: TokenResponse) -> None: + try: + with open(self.get_cache_path(cache_key), "w") as cache_file: + serialized_unencrypted_token_bytes = dumps(token.__dict__).encode(self.cache_encoding) + serialized_encrypted_token_bytes = self.encryptor.encrypt(serialized_unencrypted_token_bytes) + serialized_encrypted_token = serialized_encrypted_token_bytes.decode(self.cache_encoding) + cache_file.write(serialized_encrypted_token) + except Exception as err: + raise CacheError(err, f"Issues while writing cache to {cache_key}") from err + + def clear_cached_auth_token(self, cache_key: str) -> None: + try: + os.unlink(self.get_cache_path(cache_key)) + except FileNotFoundError: + LOG.info("Cache file does not exist, nothing to clear.") + except Exception as err: + raise CacheError(err, f"Issues while clearing cache to {cache_key}") from err + + +class CacheableTokenAuth: + """Base class for authentications which needs to use caching.""" + + def __init__(self, token_cache_strategy: TokenCacheStrategy) -> None: + self.token_cache_strategy = token_cache_strategy + + def get_cache_key(self) -> str: + """Cache key which will be used to store the token + + Returns: + String representation of the cache key + """ + + +class Cache: + """Decorator for leveraging token caching on a function which fetches a token""" + + def __init__(self, token_response_type) -> None: + self.token_response_type = token_response_type + + def __call__(self, fetch_func: Callable[..., TokenResponse]): + """Gets token from either the cache or the target system. + + If the token from the cache expired it requires a new one from the target system. + + Returns: + A valid token + + Raises: + GetAuthTokenError if there was an issue while getting the token. + """ + + @wraps(fetch_func) + def wrapper( + token_auth: CacheableTokenAuth, *args, **kwargs + ) -> self.token_response_type: # type: ignore + # Attempt to retrieve a cached access token + if isinstance(token_auth.token_cache_strategy, TokenCacheStrategy): + try: + token = token_auth.token_cache_strategy.get_cached_auth_token(token_auth.get_cache_key()) + if token.is_valid(): + + LOG.info( + "%s: %s", + "Using cached token from cache key", + token_auth.get_cache_key(), + ) + + return token + + LOG.info("Acquiring new token: cached token has expired.") + except CacheError as err: + if isinstance(err.raised_from, FileNotFoundError): + LOG.info("Acquiring new token: No cache found") + else: + LOG.warning( + ("Acquiring new token: Issue while reading the cached token." " Reason %s"), + repr(err), + ) + + try: + token = fetch_func(token_auth, *args, **kwargs) + except FetchAuthTokenError as err: + LOG.error("Could not obtain authentication token. Reason: %s", repr(err)) + raise GetAuthTokenError(err) from err + + if isinstance(token_auth.token_cache_strategy, TokenCacheStrategy): + # Cache the token + try: + token_auth.token_cache_strategy.cache_auth_token(token_auth.get_cache_key(), token) + except CacheError as err: + LOG.warning( + "%s: %s", + "Failed to cache authentication token. 
Reason: ", + repr(err), + ) + return token + + return wrapper diff --git a/airflow/providers/cloudera/sensors/__init__.py b/airflow/providers/cloudera/sensors/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/airflow/providers/cloudera/sensors/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/providers/cloudera/sensors/cdw_sensor.py b/airflow/providers/cloudera/sensors/cdw_sensor.py new file mode 100644 index 0000000000000..6d8aa2c44a47b --- /dev/null +++ b/airflow/providers/cloudera/sensors/cdw_sensor.py @@ -0,0 +1,65 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from airflow.providers.apache.hive.sensors.hive_partition import HivePartitionSensor +from airflow.providers.cloudera.hooks.cdw_hook import CdwHiveMetastoreHook + + +class CdwHivePartitionSensor(HivePartitionSensor): + """ + CdwHivePartitionSensor is a subclass of HivePartitionSensor and supposed to implement + the same logic by delegating the actual work to a CdwHiveMetastoreHook instance. + """ + + template_fields = ( + "schema", + "table", + "partition", + ) + ui_color = "#C5CAE9" + + def __init__( + self, + table, + partition="ds='{{ ds }}'", + cli_conn_id="metastore_default", + schema="default", + poke_interval=60 * 3, + *args, + **kwargs, + ): + super().__init__(table=table, poke_interval=poke_interval, *args, **kwargs) + if not partition: + partition = "ds='{{ ds }}'" + self.cli_conn_id = cli_conn_id + self.table = table + self.partition = partition + self.schema = schema + self.hook = None + + def poke(self, context): + if "." 
in self.table: + self.schema, self.table = self.table.split(".") + self.log.info( + "Poking for table %s.%s, partition %s", + self.schema, + self.table, + self.partition, + ) + if self.hook is None: + self.hook = CdwHiveMetastoreHook(cli_conn_id=self.cli_conn_id) + return self.hook.check_for_partition(self.schema, self.table, self.partition) diff --git a/airflow/providers/dependencies.json b/airflow/providers/dependencies.json index 79a58e5bc2ef1..b5eaae2a611eb 100644 --- a/airflow/providers/dependencies.json +++ b/airflow/providers/dependencies.json @@ -30,6 +30,10 @@ "apache.livy": [ "http" ], + "cloudera": [ + "apache.hive", + "http" + ], "dbt.cloud": [ "http" ], diff --git a/docs/apache-airflow-providers-cloudera/commits.rst b/docs/apache-airflow-providers-cloudera/commits.rst new file mode 100644 index 0000000000000..5f3de0295f28e --- /dev/null +++ b/docs/apache-airflow-providers-cloudera/commits.rst @@ -0,0 +1,27 @@ + + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Package apache-airflow-providers-cloudera +------------------------------------------------------ + +`Cloudera <https://cloudera.com/>`__ + + +This is the detailed commit list of changes for the ``cloudera`` provider package. +For a high-level changelog, see :doc:`package information including changelog <index>`. diff --git a/docs/apache-airflow-providers-cloudera/connections/index.rst b/docs/apache-airflow-providers-cloudera/connections/index.rst new file mode 100644 index 0000000000000..1378fe432d8ea --- /dev/null +++ b/docs/apache-airflow-providers-cloudera/connections/index.rst @@ -0,0 +1,59 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + + +.. _howto/connection:cloudera: + +Cloudera Data Engineering Connection +==================================== + +The Cloudera Data Engineering connection type enables integrations with Cloudera Data Engineering. + +Authenticating to Cloudera Data Engineering +------------------------------------------- + +Cloudera Data Engineering relies on CDP Authentication mechanisms. 
+The Cloudera Data Engineering connection currently relies only on a CDP Access/Private Key pair, +which can be set up in the user's profile as described here: + +https://docs.cloudera.com/cdp/latest/cli/topics/mc-cli-generating-an-api-access-key.html + + +Default Connection IDs +---------------------- + +Hooks and operators related to Cloudera Data Engineering use ``cde_runtime_api`` by default. + +Configuring the Connection +-------------------------- + +Host (required) + Specify the Virtual Cluster Jobs API URL that can be obtained with the following steps: + - From the CDE home page, go to Overview > Virtual Clusters > Cluster Details of the Virtual Cluster (VC) where you want the CDE job to run. + - Click JOBS API URL to copy the URL. + +CDP Access Key (required) + Provide a CDP access key of the account for running jobs on the CDE VC. + +CDP Private Key (required) + Provide the CDP private key associated with the given CDP Access Key. + +Extra (optional) + Specify the extra parameter (as a json dictionary) that can be used in the Cloudera Data Engineering connection. + + * ``cache_dir``: an alternative cache directory, for when the default cache directory cannot be used because of insufficient access rights diff --git a/docs/apache-airflow-providers-cloudera/index.rst b/docs/apache-airflow-providers-cloudera/index.rst new file mode 100644 index 0000000000000..26489720efec2 --- /dev/null +++ b/docs/apache-airflow-providers-cloudera/index.rst @@ -0,0 +1,90 @@ + + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +``apache-airflow-providers-cloudera`` +============================================ + +Content +------- + +.. toctree:: + :maxdepth: 1 + :caption: Guides + + Connection types <connections/index> + Operators <operators/index> + Sensors <sensors/hive_partition> + +.. toctree:: + :maxdepth: 1 + :caption: References + + Python API <_api/airflow/providers/cloudera/index> + +.. toctree:: + :maxdepth: 1 + :caption: Resources + + PyPI Repository <https://pypi.org/project/apache-airflow-providers-cloudera/> + Installing from sources <installing-providers-from-sources> + +.. THE REMAINDER OF THE FILE IS AUTOMATICALLY GENERATED. IT WILL BE OVERWRITTEN AT RELEASE TIME! + + +.. toctree:: + :maxdepth: 1 + :caption: Commits + + Detailed list of commits <commits> + + +Package apache-airflow-providers-cloudera +------------------------------------------------------ + +`Cloudera <https://cloudera.com/>`__ + + +Release: 1.0.0 + +Provider package +---------------- + +This is a provider package for the ``cloudera`` provider. All classes for this provider package +are in the ``airflow.providers.cloudera`` python package. 
+ +Installation +------------ + +You can install this package on top of an existing Airflow 2.1+ installation via +``pip install apache-airflow-providers-cloudera`` + +PIP requirements +---------------- + +================== ================== +PIP package        Version required +================== ================== +``apache-airflow`` ``>=2.0.0`` +``cryptography``   ``>=3.3.2`` +``pathlib`` +``pure25519`` +``requests`` +``tenacity`` +================== ================== + +.. include:: ../../airflow/providers/cloudera/CHANGELOG.rst diff --git a/docs/apache-airflow-providers-cloudera/installing-providers-from-sources.rst b/docs/apache-airflow-providers-cloudera/installing-providers-from-sources.rst new file mode 100644 index 0000000000000..1c90205d15b3a --- /dev/null +++ b/docs/apache-airflow-providers-cloudera/installing-providers-from-sources.rst @@ -0,0 +1,18 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +.. include:: ../installing-providers-from-sources.rst diff --git a/docs/apache-airflow-providers-cloudera/operators/cde_run_job.rst b/docs/apache-airflow-providers-cloudera/operators/cde_run_job.rst new file mode 100644 index 0000000000000..8e074bd6d4d05 --- /dev/null +++ b/docs/apache-airflow-providers-cloudera/operators/cde_run_job.rst @@ -0,0 +1,110 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + + +.. _howto/operator:CdeRunJobOperator: + + +CdeRunJobOperator +================== + +Use the :class:`~airflow.providers.cloudera.operators.CdeRunJobOperator` to submit a new CDE job via the CDE ``/jobs//run`` API endpoint. + + +Using the Operator +------------------ + + Runs a job in a CDE Virtual Cluster. The ``CdeRunJobOperator`` runs the + named job with optional variables and overrides. The job and its resources + must have already been created via the specified virtual cluster jobs API. + + The virtual cluster API endpoint is specified by setting the + ``connection_id`` parameter. The "local" virtual cluster jobs API is the + default and has a special value of ``cde_runtime_api``. 
Authentication to + the API is handled automatically and any jobs in the DAG will run as the + user who submitted the DAG. + + Jobs can be defined in a virtual cluster with variable placeholders, + e.g. ``{{ inputdir }}``. Currently, the fields supporting variable expansion + are Spark application name, Spark arguments, and Spark configurations. + Variables can be passed to the operator as a dictionary of key-value string + pairs. In addition to any user variables passed via the ``variables`` + parameter, the following standard Airflow macros are automatically + populated as variables by the operator (see + https://airflow.apache.org/docs/stable/macros-ref): + + * ``ds``: the execution date as ``YYYY-MM-DD`` + * ``ds_nodash``: the execution date as ``YYYYMMDD`` + * ``ts``: execution date in ISO 8601 format + * ``ts_nodash``: execution date in ISO 8601 format without '-', ':' or + timezone information + * ``run_id``: the run_id of the current DAG run + + If a CDE job needs to run with a different configuration, a task can be + configured with runtime overrides. For example, to override the Spark + executor memory and cores for a task and to supply an additional config + parameter, the following dictionary can be supplied to + the ``overrides`` parameter:: + + { + 'spark': { + 'executorMemory': '8g', + 'executorCores': '4', + 'conf': { + 'spark.kubernetes.memoryOverhead': '2048' + } + } + } + + See the CDE Jobs API documentation for the full list of parameters that + can be overridden. + + Via the ``wait`` parameter, jobs can either be submitted asynchronously to + the API (``wait=False``) or the task can wait until the job is complete + before exiting the task (default is ``wait=True``). If ``wait`` is + ``True``, the task exit status will reflect the final status of the + submitted job (or the task will fail on timeout if specified). If ``wait`` + is ``False`` the task status will reflect whether the job was successfully + submitted to the API or not. + + Note: all parameters below can also be provided through the + ``default_args`` field of the DAG. + + +.. list-table:: + :widths: 25 25 + :header-rows: 1 + + * - Parameter + - Input + * - job_name: str + - The name of the job in the target cluster, required + * - connection_id: str + - The Airflow connection id for the target API endpoint, default value ``'cde_runtime_api'`` + * - variables: dict + - A dictionary of key-value pairs to populate in the job configuration, default empty dict. + * - overrides: dict + - A dictionary of key-value pairs to override in the job configuration, default empty dict. + * - wait: bool + - If set to true, the operator will wait for the job to complete in the target cluster. The task exit status will reflect the status of the completed job. Default ``True`` + * - timeout: int + - The maximum time to wait in seconds for the job to complete if ``wait=True``. If set to ``None``, 0 or a negative number, the task will never time out. Default ``0``. + * - job_poll_interval: int + - The interval in seconds at which the target API is polled for the job status. Default ``10``. + * - api_retries: int + - The number of times to retry an API request in the event of a connection failure or non-fatal API error. Default ``9``. 
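+ +Example +------- + +Below is a minimal sketch of a DAG task using the operator. The DAG id, job name and variable values are hypothetical, and assume a job named ``etl_job`` has already been created in the target virtual cluster: + +.. code-block:: python + +    from datetime import datetime + +    from airflow import DAG +    from airflow.providers.cloudera.operators.cde_operator import CdeRunJobOperator + +    with DAG( +        dag_id="cde_example", +        start_date=datetime(2022, 1, 1), +        schedule_interval=None, +    ) as dag: +        # Submit the pre-existing CDE job and wait up to an hour for it to finish +        run_etl = CdeRunJobOperator( +            task_id="run_etl", +            job_name="etl_job", +            connection_id="cde_runtime_api", +            variables={"inputdir": "/data/input"}, +            overrides={"spark": {"executorMemory": "8g"}}, +            wait=True, +            timeout=3600, +        ) 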
diff --git a/docs/apache-airflow-providers-cloudera/operators/execute_query.rst b/docs/apache-airflow-providers-cloudera/operators/execute_query.rst new file mode 100644 index 0000000000000..e70c81b51f30d --- /dev/null +++ b/docs/apache-airflow-providers-cloudera/operators/execute_query.rst @@ -0,0 +1,52 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + + +.. _howto/operator:CdwExecuteQueryOperator: + + +CdwExecuteQueryOperator +======================= + +Use the :class:`~airflow.providers.cloudera.operators.CdwExecuteQueryOperator` to execute hql code in CDW. + + +Using the Operator +------------------ + +Executes hql code in CDW. This class inherits behavior from HiveOperator, and instantiates a CdwHook to do the work. + + +.. list-table:: + :widths: 25 25 + :header-rows: 1 + + * - Parameter + - Input + * - schema: str + - The name of the DB schema, default value ``'default'`` + * - hiveconfs: dict + - An optional dictionary of key-value pairs to define Hive configurations + * - hiveconf_jinja_translate: bool + - When ``True``, translates ``${hiveconf:var}`` style placeholders in the hql to Jinja style ``{{ var }}`` templates. Default value ``False``. + * - cli_conn_id: str + - The Airflow connection id for the target CDW instance, default value ``'hive_cli_default'`` + * - jdbc_driver: str + - Class name of the Impala JDBC driver, for instance "com.cloudera.impala.jdbc41.Driver". Required for Impala connections. None by default. + * - query_isolation: bool + - Controls whether to use CDW's query isolation feature. Only Hive warehouses support this at the moment. Default ``True``. diff --git a/docs/apache-airflow-providers-cloudera/operators/index.rst b/docs/apache-airflow-providers-cloudera/operators/index.rst new file mode 100644 index 0000000000000..91da1a7312755 --- /dev/null +++ b/docs/apache-airflow-providers-cloudera/operators/index.rst @@ -0,0 +1,28 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + + +Cloudera Operators +================== + + +.. 
toctree:: + :maxdepth: 1 + :glob: + + * diff --git a/docs/apache-airflow-providers-cloudera/sensors/hive_partition.rst b/docs/apache-airflow-providers-cloudera/sensors/hive_partition.rst new file mode 100644 index 0000000000000..ad7f25bd0f04a --- /dev/null +++ b/docs/apache-airflow-providers-cloudera/sensors/hive_partition.rst @@ -0,0 +1,51 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + + +.. _howto/sensor:CdwHivePartitionSensor: + + +CdwHivePartitionSensor +====================== + +Use the :class:`~airflow.providers.cloudera.sensors.CdwHivePartitionSensor` to check the existence of a Hive partition. + + +Using the Sensor +---------------- + +Check the existence of a Hive partition. +CdwHivePartitionSensor is a subclass of HivePartitionSensor and supposed to implement the same logic by delegating the actual work to a CdwHiveMetastoreHook instance. + + +.. list-table:: + :widths: 25 25 + :header-rows: 1 + + * - Parameter + - Input + * - table: str + - The name of the table to wait for, supports the dot notation (my_database.my_table) + * - partition: str + - Name of the hive partition to check. Default ``"ds='{{ ds }}'"`` + * - cli_conn_id: str + - The Airflow connection id for the target CDW instance, default value ``'metastore_default'`` + * - schema: str + - The name of the DB schema, default value ``'default'`` + * - poke_interval: float + - Time in seconds that the job should wait in between each tries. 
diff --git a/docs/apache-airflow/extra-packages-ref.rst b/docs/apache-airflow/extra-packages-ref.rst
index ad4fa01522217..9b250586ea120 100644
--- a/docs/apache-airflow/extra-packages-ref.rst
+++ b/docs/apache-airflow/extra-packages-ref.rst
@@ -162,6 +162,8 @@ Those are extras that add dependencies needed for integration with external serv
 +---------------------+-----------------------------------------------------+-----------------------------------------------------+
 | cloudant            | ``pip install 'apache-airflow[cloudant]'``          | Cloudant hook                                       |
 +---------------------+-----------------------------------------------------+-----------------------------------------------------+
+| cloudera            | ``pip install 'apache-airflow[cloudera]'``          | Cloudera hooks and operators                        |
++---------------------+-----------------------------------------------------+-----------------------------------------------------+
 | databricks          | ``pip install 'apache-airflow[databricks]'``        | Databricks hooks and operators                      |
 +---------------------+-----------------------------------------------------+-----------------------------------------------------+
 | datadog             | ``pip install 'apache-airflow[datadog]'``           | Datadog hooks and sensors                           |
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index cbc6b2abdd7b9..cb50aea1ecfcb 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -67,6 +67,7 @@ Cinimex
 ClassifyTextResponse
 CloudTasksClient
 Cloudant
+Cloudera
 Cloudwatch
 ClusterManagerClient
 Codecov
@@ -333,6 +334,7 @@ RefreshError
 ReidentifyContentResponse
 Reinitialising
 Remoting
+RequestException
 Reserialize
 ResourceRequirements
 Roadmap
@@ -463,6 +465,7 @@ alibaba
 allAuthenticatedUsers
 allUsers
 allowinsert
+altus
 amazonaws
 amqp
 analyse
@@ -499,6 +502,7 @@ attrs
 auth
 authMechanism
 authenticator
+authn
 authorised
 autoclass
 autocommit
@@ -570,6 +574,9 @@ catchup
 cattrs
 ccache
 cdc
+cde
+cdpcurl
+cdw
 celeryd
 celltags
 cfg
@@ -590,9 +597,11 @@ classable
 classname
 classpath
 classpaths
+cldr
 cli
 clientId
 cloudant
+cloudera
 cloudml
 cloudsqldatabehook
 cloudwatch
@@ -747,6 +756,7 @@ dsn
 dttm
 dtypes
 durations
+dwx
 dylib
 dynamodb
 dynload
@@ -886,6 +896,7 @@ hostnames
 hotfix
 howto
 hql
+hs
 html
 htmlcontent
 http
@@ -985,6 +996,7 @@ kylin
 lastname
 latencies
 latin
+lbodor
 ldap
 ldaps
 leveldb
@@ -1482,6 +1494,7 @@ uris
 url
 urlencoded
 urlparse
+urls
 useHCatalog
 useLegacySQL
 useQueryCache
@@ -1497,6 +1510,7 @@ utils
 uuid
 validator
 vals
+vcluster
 ve
 vendored
 venvs
@@ -1530,6 +1544,7 @@ www
 xcom
 xcomarg
 xcomresult
+xkg
 xml
 xpath
 xyz
diff --git a/setup.py b/setup.py
index 7935b7663a7ae..f59581e68c911 100644
--- a/setup.py
+++ b/setup.py
@@ -258,6 +258,14 @@ def write_version(filename: str = os.path.join(*[my_dir, "airflow", "git_version
 cloudant = [
     'cloudant>=2.0',
 ]
+cloudera = [
+    'pure25519',
+    # Minimum crypto version for handling CVE-2020-36242
+    'cryptography>=3.3.2',
+    'tenacity',
+    'requests',
+    'pathlib',
+]
 dask = [
     # Dask support is limited, we need Dask team to upgrade support for dask if we were to continue
     # Supporting it in the future
@@ -691,6 +699,7 @@ def write_version(filename: str = os.path.join(*[my_dir, "airflow", "git_version
     'asana': asana,
     'celery': celery,
     'cloudant': cloudant,
+    'cloudera': cloudera,
     'cncf.kubernetes': kubernetes,
     'databricks': databricks,
     'datadog': datadog,
diff --git a/tests/providers/cloudera/__init__.py b/tests/providers/cloudera/__init__.py
new file mode 100644
index 0000000000000..13a83393a9124
--- /dev/null
+++ b/tests/providers/cloudera/__init__.py
@@
-0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/providers/cloudera/hooks/__init__.py b/tests/providers/cloudera/hooks/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/tests/providers/cloudera/hooks/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/providers/cloudera/hooks/test_cde_hook.py b/tests/providers/cloudera/hooks/test_cde_hook.py new file mode 100644 index 0000000000000..e09483f4982a3 --- /dev/null +++ b/tests/providers/cloudera/hooks/test_cde_hook.py @@ -0,0 +1,225 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""Unit Tests for CdeHook related operations""" + +import unittest +from unittest import mock + +from requests import Session + +from airflow.exceptions import AirflowException +from airflow.hooks.base_hook import BaseHook +from airflow.models import Connection +from airflow.providers.cloudera.hooks.cde_hook import CdeHook, CdeHookException +from airflow.providers.cloudera.security.cde_security import CdeApiTokenAuth, CdeTokenAuthResponse +from airflow.utils.log.logging_mixin import LoggingMixin, logging # type: ignore +from tests.providers.cloudera.utils import _get_call_arguments, _make_response + +LOG = logging.getLogger(__name__) +LOG.setLevel(logging.DEBUG) + +TEST_HOST = 'https://vc1.cde-2.cdp-3.cloudera.site' +TEST_SCHEME = 'http' +TEST_PORT = 9090 +TEST_JOB_NAME = 'testjob' +TEST_VARIABLES = { + 'var1': 'someval_{{ ds_nodash }}', + 'ds': '2020-11-25', + 'ds_nodash': '20201125', + 'ts': '2020-11-25T00:00:00+00:00', + 'ts_nodash': '20201125T000000', + 'run_id': 'runid', +} +TEST_OVERRIDES = {'spark': {'conf': {'myparam': 'val_{{ ds_nodash }}'}}} +TEST_AK = "access_key" +TEST_PK = "private_key_xxxxx_xxxxx_xxxxx_xxxxx" +TEST_CUSTOM_CA_CERTIFICATE = "/ca_cert/letsencrypt-stg-root-x1.pem" +TEST_EXTRA = f'{{"ca_cert_path": "{TEST_CUSTOM_CA_CERTIFICATE}"}}' + + +def _get_test_connection(**kwargs): + kwargs = {**TEST_DEFAULT_CONNECTION_DICT, **kwargs} + return Connection(**kwargs) + + +TEST_DEFAULT_CONNECTION_DICT = { + 'conn_id': CdeHook.DEFAULT_CONN_ID, + 'conn_type': 'http', + 'host': TEST_HOST, + 'login': TEST_AK, + 'password': TEST_PK, + 'port': TEST_PORT, + 'schema': TEST_SCHEME, + 'extra': TEST_EXTRA, +} + +TEST_DEFAULT_CONNECTION = _get_test_connection() + +VALID_CDE_TOKEN = "my_cde_token" +VALID_CDE_TOKEN_AUTH_REQUEST_RESPONSE = _make_response( + 200, {"access_token": VALID_CDE_TOKEN, "expires_in": 123}, "" +) +VALID_CDE_TOKEN_AUTH_RESPONSE = CdeTokenAuthResponse.from_response(VALID_CDE_TOKEN_AUTH_REQUEST_RESPONSE) + + +class CdeHookTest(unittest.TestCase, LoggingMixin): + """Unit tests for CdeHook""" + + @mock.patch.object( + BaseHook, 'get_connection', return_value=_get_test_connection(extra='{"insecure": False}') + ) + def test_wrong_extra_in_connection(self, connection_mock): + """Test when wrong input is provided in the extra field of the connection""" + with self.assertRaises(ValueError): + CdeHook() + connection_mock.assert_called() + + @mock.patch( + 'airflow.providers.cloudera.security.cde_security.CdeApiTokenAuth.get_cde_authentication_token', + return_value=VALID_CDE_TOKEN_AUTH_RESPONSE, + ) + @mock.patch.object(Session, 'send', return_value=_make_response(201, {'id': 10}, "")) + @mock.patch.object(BaseHook, 'get_connection', return_value=TEST_DEFAULT_CONNECTION) + def test_submit_job_ok(self, connection_mock, session_send_mock, cde_mock): + """Test a successful submission to the API""" + cde_hook = CdeHook() + run_id = cde_hook.submit_job(TEST_JOB_NAME) + self.assertEqual(run_id, 10) + cde_mock.assert_called() + connection_mock.assert_called() + session_send_mock.assert_called() + + @mock.patch( + 'airflow.providers.cloudera.security.cde_security.CdeApiTokenAuth.get_cde_authentication_token', + return_value=VALID_CDE_TOKEN_AUTH_RESPONSE, + ) + @mock.patch.object(Session, 'send', return_value=_make_response(201, {'id': 10}, "")) + @mock.patch.object(BaseHook, 'get_connection', return_value=_get_test_connection(host='abc.svc')) + def test_submit_job_ok_internal_connection(self, connection_mock, session_send_mock, cde_mock: mock.Mock): + """Test a successful submission 
to the API""" + cde_hook = CdeHook() + run_id = cde_hook.submit_job(TEST_JOB_NAME) + self.assertEqual(run_id, 10) + cde_mock.assert_not_called() + connection_mock.assert_called() + session_send_mock.assert_called() + + @mock.patch.object( + CdeApiTokenAuth, 'get_cde_authentication_token', return_value=VALID_CDE_TOKEN_AUTH_RESPONSE + ) + @mock.patch.object(BaseHook, 'get_connection', return_value=TEST_DEFAULT_CONNECTION) + @mock.patch.object( + Session, + 'send', + side_effect=[ + _make_response(503, None, "Internal Server Error"), + _make_response(500, None, "Internal Server Error"), + _make_response(201, {'id': 10}, ""), + ], + ) + def test_submit_job_retry_after_5xx_works(self, send_mock, connection_mock, cde_mock): + """Ensure that 5xx errors are retried""" + cde_hook = CdeHook() + run_id = cde_hook.submit_job(TEST_JOB_NAME) + self.assertEqual(run_id, 10) + self.assertEqual(cde_mock.call_count, 1) + self.assertEqual(send_mock.call_count, 3) + connection_mock.assert_called() + + @mock.patch( + 'airflow.providers.cloudera.security.cde_security.CdeApiTokenAuth.get_cde_authentication_token', + return_value=VALID_CDE_TOKEN_AUTH_RESPONSE, + ) + @mock.patch.object(BaseHook, 'get_connection', return_value=TEST_DEFAULT_CONNECTION) + @mock.patch.object(Session, 'send', return_value=_make_response(404, None, "Not Found")) + def test_submit_job_fails_immediately_for_4xx(self, send_mock, connection_mock, cde_mock): + """Ensure that 4xx errors are _not_ retried""" + cde_hook = CdeHook() + with self.assertRaises(CdeHookException) as err: + cde_hook.submit_job(TEST_JOB_NAME) + self.assertEqual(send_mock.call_count, 1) + self.assertIsInstance(err.exception.raised_from, AirflowException) + cde_mock.assert_called() + connection_mock.assert_called() + + @mock.patch( + 'airflow.providers.cloudera.security.cde_security.CdeApiTokenAuth.get_cde_authentication_token', + return_value=VALID_CDE_TOKEN_AUTH_RESPONSE, + ) + @mock.patch.object(Session, 'send', return_value=_make_response(201, {'id': 10}, "")) + @mock.patch.object( + BaseHook, 'get_connection', return_value=_get_test_connection(extra='{"insecure": true}') + ) + def test_submit_job_insecure(self, connection_mock, session_send_mock, cde_mock): + """Ensure insecure mode is taken into account""" + cde_hook = CdeHook() + run_id = cde_hook.submit_job(TEST_JOB_NAME) + self.assertEqual(run_id, 10) + cde_mock.assert_called() + connection_mock.assert_called() + session_send_mock.assert_called() + called_args = _get_call_arguments(session_send_mock.call_args) + self.assertEqual(called_args['verify'], False) + + @mock.patch( + 'airflow.providers.cloudera.security.cde_security.CdeApiTokenAuth.get_cde_authentication_token', + return_value=VALID_CDE_TOKEN_AUTH_RESPONSE, + ) + @mock.patch.object(Session, 'send', return_value=_make_response(201, {'id': 10}, "")) + @mock.patch.object(BaseHook, 'get_connection', return_value=_get_test_connection(extra='{}')) + def test_submit_job_no_custom_ca_certificate(self, connection_mock, session_send_mock, cde_mock): + """Ensure that default TLS security configuration runs fine""" + cde_hook = CdeHook() + run_id = cde_hook.submit_job(TEST_JOB_NAME) + self.assertEqual(run_id, 10) + cde_mock.assert_called() + connection_mock.assert_called() + session_send_mock.assert_called() + called_args = _get_call_arguments(session_send_mock.call_args) + self.assertEqual(called_args['verify'], True) + + @mock.patch( + 'airflow.providers.cloudera.security.cde_security.CdeApiTokenAuth.get_cde_authentication_token', + 
+        return_value=VALID_CDE_TOKEN_AUTH_RESPONSE,
+    )
+    @mock.patch.object(Session, 'send', return_value=_make_response(201, {'id': 10}, ""))
+    @mock.patch.object(BaseHook, 'get_connection', return_value=TEST_DEFAULT_CONNECTION)
+    def test_submit_job_custom_ca_certificate(self, connection_mock, session_send_mock, cde_mock):
+        """Ensure a custom CA certificate is taken into account"""
+        cde_hook = CdeHook()
+        run_id = cde_hook.submit_job(TEST_JOB_NAME)
+        self.assertEqual(run_id, 10)
+        cde_mock.assert_called()
+        connection_mock.assert_called()
+        session_send_mock.assert_called()
+        called_args = _get_call_arguments(session_send_mock.call_args)
+        self.assertEqual(called_args['verify'], TEST_CUSTOM_CA_CERTIFICATE)
+
+    @mock.patch.object(
+        BaseHook, 'get_connection', return_value=_get_test_connection(extra='{"cache_dir": " "}')
+    )
+    def test_wrong_cache_dir(self, connection_mock):
+        """Ensure that job submission fails if the cache dir value is wrong"""
+        cde_hook = CdeHook()
+        with self.assertRaises(CdeHookException):
+            cde_hook.submit_job(TEST_JOB_NAME)
+        connection_mock.assert_called()
+
+
+if __name__ == "__main__":
+    unittest.main()
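The hook tests above pin down a retry contract: 5xx responses are retried, 4xx responses fail immediately. A minimal sketch of such a policy, assuming a tenacity-based loop (tenacity is among the provider's declared dependencies); the real CdeHook internals are not shown in this diff:

import requests
import tenacity


def _is_server_error(response: requests.Response) -> bool:
    # 5xx responses are treated as transient and retried; 4xx responses indicate
    # a client problem and surface immediately, as the tests above expect.
    return response.status_code >= 500


@tenacity.retry(
    stop=tenacity.stop_after_attempt(3),
    wait=tenacity.wait_fixed(1),
    retry=tenacity.retry_if_result(_is_server_error),
    # Hand back the last (still failing) response instead of raising RetryError
    retry_error_callback=lambda retry_state: retry_state.outcome.result(),
)
def send_with_retries(session: requests.Session, request: requests.PreparedRequest) -> requests.Response:
    return session.send(request)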
diff --git a/tests/providers/cloudera/hooks/test_cdw_hook.py b/tests/providers/cloudera/hooks/test_cdw_hook.py
new file mode 100644
index 0000000000000..238f072310ebc
--- /dev/null
+++ b/tests/providers/cloudera/hooks/test_cdw_hook.py
@@ -0,0 +1,138 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from airflow.hooks.base_hook import BaseHook
+from airflow.models import Connection
+from airflow.providers.cloudera.hooks.cdw_hook import CdwHook
+
+
+def test_beeline_command_hive(mocker):
+    """
+    Tests whether the expected beeline command is generated from CdwHook's parameters.
+    """
+    mocker.patch.object(
+        BaseHook,
+        "get_connection",
+        return_value=Connection(
+            conn_id='fake',
+            conn_type='hive_cli',
+            host='hs2-beeline.host',
+            login='user',
+            password='pass',
+            schema='hello',
+            port='10001',
+            extra=None,
+            uri=None,
+        ),
+    )
+    hook = CdwHook(cli_conn_id='anything')
+    beeline_command = hook.get_cli_cmd()
+    assert (
+        ' '.join(beeline_command) == 'beeline -u jdbc:hive2://hs2-beeline.host/hello;'
+        'transportMode=http;httpPath=cliservice;ssl=true -n user -p pass '
+        '--hiveconf hive.query.isolation.scan.size.threshold=0B '
+        '--hiveconf hive.query.results.cache.enabled=false '
+        '--hiveconf hive.auto.convert.join.noconditionaltask.size=2505397589'
+    ), 'invalid beeline command'
+
+
+def test_beeline_command_impala(mocker):
+    """
+    Tests whether the expected beeline command is generated from CdwHook's parameters.
+    CdwHook will force the following by default in case of Impala:
+    port: 443 (regardless of setting)
+    AuthMech: should be present, default 3
+    """
+    mocker.patch.object(
+        BaseHook,
+        "get_connection",
+        return_value=Connection(
+            conn_id='fake',
+            conn_type='hive_cli',
+            host='impala-proxy-beeline.host',
+            login='user',
+            password='pass',
+            schema='hello',
+            port='7777',
+            extra=None,
+            uri=None,
+        ),
+    )
+    hook = CdwHook(cli_conn_id='anything')
+    beeline_command = hook.get_cli_cmd()
+    assert (
+        ' '.join(beeline_command) == 'beeline -d com.cloudera.impala.jdbc41.Driver '
+        '-u jdbc:impala://impala-proxy-beeline.host:443/hello;AuthMech=3;'
+        'transportMode=http;httpPath=cliservice;ssl=1 -n user -p pass'
+    ), 'invalid beeline command'
+
+
+def test_beeline_command_impala_custom_driver(mocker):
+    """
+    Tests whether the expected beeline command is generated from CdwHook's
+    parameters with a custom Impala driver.
+    """
+    mocker.patch.object(
+        BaseHook,
+        "get_connection",
+        return_value=Connection(
+            conn_id='fake',
+            conn_type='hive_cli',
+            host='impala-proxy-beeline.host',
+            login='user',
+            password='pass',
+            schema='hello',
+            port='7777',
+            extra=None,
+            uri=None,
+        ),
+    )
+    custom_driver = 'com.impala.another.driver'
+    hook = CdwHook(cli_conn_id='anything', jdbc_driver=custom_driver)
+    beeline_command = hook.get_cli_cmd()
+    assert (
+        ' '.join(beeline_command) == 'beeline -d ' + custom_driver + ' '
+        '-u jdbc:impala://impala-proxy-beeline.host:443/hello;AuthMech=3;'
+        'transportMode=http;httpPath=cliservice;ssl=1 -n user -p pass'
+    ), 'invalid beeline command'
+
+
+def test_beeline_command_non_isolation(mocker):
+    """
+    Tests whether the expected beeline command is generated from CdwHook's parameters without isolation.
+    """
+    mocker.patch.object(
+        BaseHook,
+        "get_connection",
+        return_value=Connection(
+            conn_id='fake',
+            conn_type='hive_cli',
+            host='beeline.host',
+            login='user',
+            password='pass',
+            schema='hello',
+            port='10001',
+            extra=None,
+            uri=None,
+        ),
+    )
+    hook = CdwHook(cli_conn_id='anything', query_isolation=False)
+    beeline_command = hook.get_cli_cmd()
+    assert (
+        ' '.join(beeline_command) == 'beeline -u jdbc:hive2://beeline.host/hello;'
+        'transportMode=http;httpPath=cliservice;ssl=true -n user -p pass'
+    ), 'invalid beeline command'
diff --git a/tests/providers/cloudera/operators/__init__.py b/tests/providers/cloudera/operators/__init__.py
new file mode 100644
index 0000000000000..13a83393a9124
--- /dev/null
+++ b/tests/providers/cloudera/operators/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
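A short usage sketch tying the beeline tests above together: CdwHook turns an Airflow ``hive_cli`` connection into a beeline command line. The connection ids below are hypothetical; each would point at a connection shaped like the ones mocked in the tests:

from airflow.providers.cloudera.hooks.cdw_hook import CdwHook

# Hive virtual warehouse: isolation hiveconfs are appended by default.
hive_hook = CdwHook(cli_conn_id="cdw_hive")
print(" ".join(hive_hook.get_cli_cmd()))  # beeline -u jdbc:hive2://...

# Impala virtual warehouse: a JDBC driver class is required and port 443 is forced.
impala_hook = CdwHook(cli_conn_id="cdw_impala", jdbc_driver="com.cloudera.impala.jdbc41.Driver")
print(" ".join(impala_hook.get_cli_cmd()))  # beeline -d com.cloudera.impala.jdbc41.Driver -u jdbc:impala://...:443/...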
diff --git a/tests/providers/cloudera/operators/test_cde_job_run_operator.py b/tests/providers/cloudera/operators/test_cde_job_run_operator.py new file mode 100644 index 0000000000000..76d41071f0a05 --- /dev/null +++ b/tests/providers/cloudera/operators/test_cde_job_run_operator.py @@ -0,0 +1,234 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Tests related to the CDE Job operator""" + +import unittest +from datetime import datetime +from unittest import mock +from unittest.mock import Mock, call + +from airflow.exceptions import AirflowException +from airflow.models.connection import Connection +from airflow.models.dag import DAG +from airflow.providers.cloudera.hooks.cde_hook import CdeHook +from airflow.providers.cloudera.operators.cde_operator import CdeRunJobOperator +from tests.providers.cloudera.utils import _get_call_arguments + +TEST_JOB_NAME = 'testjob' +TEST_JOB_RUN_ID = 10 +TEST_TIMEOUT = 4 +TEST_JOB_POLL_INTERVAL = 1 +TEST_VARIABLES = {'var1': 'someval_{{ ds_nodash }}'} +TEST_OVERRIDES = {'spark': {'conf': {'myparam': 'val_{{ ds_nodash }}'}}} +TEST_CONTEXT = { + 'ds': '2020-11-25', + 'ds_nodash': '20201125', + 'ts': '2020-11-25T00:00:00+00:00', + 'ts_nodash': '20201125T000000', + 'run_id': 'runid', +} + +TEST_HOST = 'vc1.cde-2.cdp-3.cloudera.site' +TEST_SCHEME = 'http' +TEST_PORT = 9090 +TEST_AK = "access_key" +TEST_PK = "private_key" +TEST_CUSTOM_CA_CERTIFICATE = "/ca_cert/letsencrypt-stg-root-x1.pem" +TEST_EXTRA = ( + f'{{"access_key": "{TEST_AK}", "private_key": "{TEST_PK}",' f'"ca_cert": "{TEST_CUSTOM_CA_CERTIFICATE}"}}' +) + +TEST_DEFAULT_CONNECTION_DICT = { + 'conn_id': CdeHook.DEFAULT_CONN_ID, + 'conn_type': 'http', + 'host': TEST_HOST, + 'port': TEST_PORT, + 'schema': TEST_SCHEME, + 'extra': TEST_EXTRA, +} + +TEST_DEFAULT_CONNECTION = Connection( + conn_id=CdeHook.DEFAULT_CONN_ID, + conn_type='http', + host=TEST_HOST, + port=TEST_PORT, + schema=TEST_SCHEME, + extra=TEST_EXTRA, +) + + +@mock.patch.object(CdeHook, 'get_connection', return_value=TEST_DEFAULT_CONNECTION) +class CdeRunJobOperatorTest(unittest.TestCase): + + """Test cases for CDE operator""" + + def test_init(self, get_connection: Mock): + """Test constructor""" + cde_operator = CdeRunJobOperator( + task_id="task", + job_name=TEST_JOB_NAME, + variables=TEST_VARIABLES, + overrides=TEST_OVERRIDES, + ) + get_connection.assert_called() + self.assertEqual(cde_operator.job_name, TEST_JOB_NAME) + self.assertDictEqual(cde_operator.variables, TEST_VARIABLES) + self.assertDictEqual(cde_operator.overrides, TEST_OVERRIDES) + self.assertEqual(cde_operator.connection_id, CdeRunJobOperator.DEFAULT_CONNECTION_ID) + self.assertEqual(cde_operator.wait, CdeRunJobOperator.DEFAULT_WAIT) + self.assertEqual(cde_operator.timeout, CdeRunJobOperator.DEFAULT_TIMEOUT) + 
+        self.assertEqual(cde_operator.job_poll_interval, CdeRunJobOperator.DEFAULT_POLL_INTERVAL)
+        self.assertEqual(cde_operator.api_retries, CdeRunJobOperator.DEFAULT_RETRIES)
+
+    @mock.patch.object(CdeHook, 'submit_job', return_value=TEST_JOB_RUN_ID)
+    @mock.patch.object(CdeHook, 'check_job_run_status', side_effect=['starting', 'running', 'succeeded'])
+    def test_execute_and_wait(self, check_job_mock, submit_mock, get_connection):
+        """Test executing a job run and waiting for success"""
+        cde_operator = CdeRunJobOperator(
+            task_id="task",
+            job_name=TEST_JOB_NAME,
+            variables=TEST_VARIABLES,
+            overrides=TEST_OVERRIDES,
+            timeout=TEST_TIMEOUT,
+            job_poll_interval=TEST_JOB_POLL_INTERVAL,
+        )
+        get_connection.assert_called()
+        cde_operator.execute(TEST_CONTEXT)
+        # Python 3.8 works with called_args = submit_mock.call_args.kwargs,
+        # but the kwargs attribute is missing in <=3.7.1
+        called_args = _get_call_arguments(submit_mock.call_args)
+        self.assertIsInstance(called_args, dict)
+        self.assertEqual(dict(called_args['variables'], **TEST_VARIABLES), called_args['variables'])
+        self.assertEqual(dict(called_args['variables'], **TEST_CONTEXT), called_args['variables'])
+        self.assertDictEqual(TEST_OVERRIDES, called_args['overrides'])
+        check_job_mock.assert_has_calls(
+            [
+                call(TEST_JOB_RUN_ID),
+                call(TEST_JOB_RUN_ID),
+                call(TEST_JOB_RUN_ID),
+            ]
+        )
+
+    @mock.patch.object(CdeHook, 'submit_job', return_value=TEST_JOB_RUN_ID)
+    @mock.patch.object(CdeHook, 'check_job_run_status')
+    def test_execute_and_do_not_wait(self, check_job_mock, submit_mock, get_connection):
+        """Test executing a job and not waiting"""
+        cde_operator = CdeRunJobOperator(
+            task_id="task",
+            job_name=TEST_JOB_NAME,
+            variables=TEST_VARIABLES,
+            overrides=TEST_OVERRIDES,
+            timeout=TEST_TIMEOUT,
+            job_poll_interval=TEST_JOB_POLL_INTERVAL,
+            wait=False,
+        )
+        get_connection.assert_called()
+        cde_operator.execute(TEST_CONTEXT)
+        # Python 3.8 works with called_args = submit_mock.call_args.kwargs,
+        # but the kwargs attribute is missing in <=3.7.1
+        called_args = _get_call_arguments(submit_mock.call_args)
+        self.assertEqual(dict(called_args['variables'], **TEST_VARIABLES), called_args['variables'])
+        self.assertEqual(dict(called_args['variables'], **TEST_CONTEXT), called_args['variables'])
+        self.assertDictEqual(TEST_OVERRIDES, called_args['overrides'])
+        check_job_mock.assert_not_called()
+
+    @mock.patch.object(CdeHook, 'kill_job_run')
+    def test_on_kill(self, kill_job_mock, get_connection):
+        """Test killing a running job"""
+        cde_operator = CdeRunJobOperator(
+            task_id="task",
+            job_name=TEST_JOB_NAME,
+            variables=TEST_VARIABLES,
+            overrides=TEST_OVERRIDES,
+            timeout=TEST_TIMEOUT,
+            job_poll_interval=TEST_JOB_POLL_INTERVAL,
+        )
+        get_connection.assert_called()
+        cde_operator._job_run_id = 1  # pylint: disable=W0212
+        cde_operator.on_kill()
+        kill_job_mock.assert_called()
+
+    @mock.patch.object(CdeHook, 'check_job_run_status', return_value='starting')
+    def test_wait_for_job_times_out(self, check_job_mock, get_connection):
+        """Test a job run timeout"""
+        cde_operator = CdeRunJobOperator(
+            task_id="task",
+            job_name=TEST_JOB_NAME,
+            timeout=TEST_TIMEOUT,
+            job_poll_interval=TEST_JOB_POLL_INTERVAL,
+        )
+        get_connection.assert_called()
+        with self.assertRaisesRegex(TimeoutError, f'Job run did not complete in {TEST_TIMEOUT} seconds'):
+            cde_operator.wait_for_job()
+        check_job_mock.assert_called()
+
+    @mock.patch.object(CdeHook, 'check_job_run_status', side_effect=['failed', 'killed', 'unknown'])
+    def test_wait_for_job_fails_failed_status(self, check_job_mock, get_connection):
+        """Test a failed job run"""
+        cde_operator = CdeRunJobOperator(
+            task_id="task",
+            job_name=TEST_JOB_NAME,
+            timeout=TEST_TIMEOUT,
+            job_poll_interval=TEST_JOB_POLL_INTERVAL,
+        )
+        get_connection.assert_called()
+        for status in ['failed', 'killed', 'unknown']:
+            with self.assertRaisesRegex(AirflowException, f'Job run exited with {status} status'):
+                cde_operator.wait_for_job()
+        check_job_mock.assert_called()
+
+    @mock.patch.object(CdeHook, 'check_job_run_status', return_value='not_a_status')
+    def test_wait_for_job_fails_unexpected_status(self, check_job_mock, get_connection):
+        """Test an unusual status from the API"""
+        cde_operator = CdeRunJobOperator(
+            task_id="task",
+            job_name=TEST_JOB_NAME,
+            timeout=TEST_TIMEOUT,
+            job_poll_interval=TEST_JOB_POLL_INTERVAL,
+        )
+        get_connection.assert_called()
+        with self.assertRaisesRegex(
+            AirflowException, 'Got unexpected status when polling for job: not_a_status'
+        ):
+            cde_operator.wait_for_job()
+        check_job_mock.assert_called()
+
+    def test_templating(self, get_connection):
+        """Test templated fields"""
+        dag = DAG("dagid", start_date=datetime.now())
+        cde_operator = CdeRunJobOperator(
+            dag=dag,
+            task_id="task",
+            job_name=TEST_JOB_NAME,
+            variables=TEST_VARIABLES,
+            overrides=TEST_OVERRIDES,
+        )
+        get_connection.assert_called()
+        cde_operator.render_template_fields(TEST_CONTEXT)
+        self.assertEqual(dict(cde_operator.variables, **{'var1': 'someval_20201125'}), cde_operator.variables)
+        self.assertDictEqual(cde_operator.overrides, {'spark': {'conf': {'myparam': 'val_20201125'}}})
+
+
+if __name__ == "__main__":
+    unittest.main()
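A minimal DAG sketch for the operator exercised by these tests. The import path and argument names are taken from the test module above; the job name and intervals are placeholders:

from datetime import datetime

from airflow import DAG
from airflow.providers.cloudera.operators.cde_operator import CdeRunJobOperator

with DAG("cde_example", start_date=datetime(2022, 1, 1), schedule_interval=None) as dag:
    run_job = CdeRunJobOperator(
        task_id="run_job",
        job_name="testjob",  # name of a job that already exists in the CDE virtual cluster
        variables={"var1": "someval_{{ ds_nodash }}"},  # Jinja-templated, as in test_templating
        overrides={"spark": {"conf": {"myparam": "val_{{ ds_nodash }}"}}},
        timeout=3600,  # placeholder: give up after an hour
        job_poll_interval=30,  # placeholder: poll the run status every 30 seconds
        wait=True,  # block until the job run reaches a terminal state
    )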
diff --git a/tests/providers/cloudera/security/__init__.py b/tests/providers/cloudera/security/__init__.py
new file mode 100644
index 0000000000000..13a83393a9124
--- /dev/null
+++ b/tests/providers/cloudera/security/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/providers/cloudera/security/test_security.py b/tests/providers/cloudera/security/test_security.py
new file mode 100644
index 0000000000000..5c06387b14e54
--- /dev/null
+++ b/tests/providers/cloudera/security/test_security.py
@@ -0,0 +1,482 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Tests related to the API Token Authentication feature"""
+
+import os
+from datetime import datetime, timedelta
+from json import JSONDecodeError, dump, dumps
+from unittest import TestCase, main
+from unittest.mock import Mock, patch
+
+import requests
+from cryptography.fernet import Fernet
+from tenacity import wait_none
+
+from airflow.providers.cloudera.model.cdp.cde import VirtualCluster
+from airflow.providers.cloudera.security import ClientError, ServerError, submit_request
+from airflow.providers.cloudera.security.cde_security import (
+    CdeApiTokenAuth,
+    CdeTokenAuthResponse,
+    GetAuthTokenError,
+)
+from airflow.providers.cloudera.security.cdp_requests.cdpcurl import make_request
+from airflow.providers.cloudera.security.cdp_security import (
+    CdpAccessKeyCredentials,
+    CdpAccessKeyV2TokenAuth,
+    CdpApiAError,
+    CdpSecurityError,
+    CdpTokenAuthResponse,
+    GetCrnError,
+)
+from airflow.providers.cloudera.security.token_cache import (
+    CacheError,
+    EncryptedFileTokenCacheStrategy,
+    TokenCacheStrategy,
+)
+from airflow.utils.log.logging_mixin import LoggingMixin, logging  # type: ignore
+from tests.providers.cloudera.utils import _get_call_arguments, _make_response, iter_len_plus_one
+
+LOG = logging.getLogger(__name__)
+LOG.setLevel(logging.DEBUG)
+
+# Speed up tests when the retry mechanism is used in failing requests
+submit_request.retry.wait = wait_none()  # type: ignore
+
+
+TEST_SERVICE_ID = "cluster-5f95z6zc"
+TEST_AK = "access_key"
+TEST_PK = "private_key_xxxxx_xxxxx_xxxxx_xx"
+TEST_ENC_KEY = Fernet(TokenCacheStrategy.get_fernet_encryption_key(TEST_PK))
+CDP_AUTH_AKV2_TEST: CdpAccessKeyV2TokenAuth = CdpAccessKeyV2TokenAuth(
+    TEST_SERVICE_ID, CdpAccessKeyCredentials(TEST_AK, TEST_PK)
+)
+TEST_VC_HOST = "k7s2ktbd.cde-5f95z6zc.dex-dev.xcu2-8y8x.dev.cldr.work"
+TEST_VC = VirtualCluster(f"https://{TEST_VC_HOST}/dex")
+TEST_CACHE_KEY = f"{TEST_AK}____{TEST_VC_HOST}"
+TEST_CACHE_KEY_PATH = f"{EncryptedFileTokenCacheStrategy.CACHE_SUB_DIR}/{TEST_CACHE_KEY}"
+TEST_CDE_AUTH_CACHE_STRATEGY = EncryptedFileTokenCacheStrategy(
+    CdeTokenAuthResponse, encryption_key=TEST_PK, cache_dir="."
+)
+CDE_AUTH_AKV2_TEST = CdeApiTokenAuth(TEST_VC, CDP_AUTH_AKV2_TEST, TEST_CDE_AUTH_CACHE_STRATEGY)
+VALID_CDE_TOKEN = "my_cde_token"
+# Multiplied by 1000 to mimic the millisecond precision of the epoch timestamp,
+# since the CDE API gives back expiry times in this format
+VALID_CDE_TOKEN_RESPONSE_BODY = {
+    "access_token": VALID_CDE_TOKEN,
+    "expires_in": (datetime.now() + timedelta(hours=1)).timestamp() * 1000,
+}
+VALID_CDE_TOKEN_AUTH_REQUEST_RESPONSE = _make_response(200, VALID_CDE_TOKEN_RESPONSE_BODY, "")
+VALID_CDE_TOKEN_AUTH_RESPONSE = CdeTokenAuthResponse.from_response(VALID_CDE_TOKEN_AUTH_REQUEST_RESPONSE)
+VALID_CDP_TOKEN = "my_cdp_token"
+VALID_CDP_TOKEN_AUTH_REQUEST_RESPONSE = _make_response(200, {"token": VALID_CDP_TOKEN, "expiresAt": 2345}, "")
+VALID_CDP_TOKEN_AUTH_RESPONSE = CdpTokenAuthResponse(VALID_CDP_TOKEN_AUTH_REQUEST_RESPONSE)
+
+
+class CDPRequestsTestCase(TestCase, LoggingMixin):
+    """Tests related to the requests issued to the CDP API"""
+
+    @patch(
+        'airflow.providers.cloudera.security.cdp_requests.cdpcurl.make_signature_header',
+        return_value="signature",
+    )
+    @patch(
+        'airflow.providers.cloudera.security.requests.request',
+        side_effect=[
+            requests.exceptions.Timeout,
+            _make_response(500, None, "Internal Server Error"),
+            _make_response(401, None, "Unauthorized"),
+        ],
+    )
+    def test_make_request_with_client_issue(self, make_request_mock: Mock, make_sig_mock: Mock):
+        """Checks that ClientError is raised in case of client-side (4xx) issues."""
+        with self.assertRaises(ClientError):
+            headers = {'Content-Type': 'application/json'}
+            request_body = ""
+            make_request(
+                "POST",
+                "altus_iam_gen_workload_auth_endpoint",
+                headers,
+                request_body,
+                TEST_AK,
+                TEST_PK,
+                False,
+                True,
+            )
+        make_request_mock.assert_called()
+        make_sig_mock.assert_called()
+
+    @patch(
+        'airflow.providers.cloudera.security.cdp_requests.cdpcurl.make_signature_header',
+        return_value="signature",
+    )
+    @patch(
+        'airflow.providers.cloudera.security.requests.request',
+        side_effect=[
+            _make_response(500, None, "Internal Server Error"),
+            requests.exceptions.Timeout,
+            _make_response(500, None, "Internal Server Error"),
+        ],
+    )
+    def test_make_request_with_server_issue(self, make_request_mock: Mock, make_sig_mock: Mock):
+        """Checks that ServerError is raised when the request reaches the max retry count
+        and a 5xx error is returned"""
+        with self.assertRaises(ServerError):
+            headers = {'Content-Type': 'application/json'}
+            request_body = ""
+            make_request(
+                "POST",
+                "altus_iam_gen_workload_auth_endpoint",
+                headers,
+                request_body,
+                TEST_AK,
+                TEST_PK,
+                False,
+                True,
+            )
+        make_request_mock.assert_called()
+        make_sig_mock.assert_called()
+
+    @patch(
+        'airflow.providers.cloudera.security.cdp_requests.cdpcurl.make_signature_header',
+        return_value="signature",
+    )
+    @patch(
+        'airflow.providers.cloudera.security.requests.request',
+        side_effect=[
+            _make_response(500, None, "Internal Server Error"),
+            requests.exceptions.Timeout,
+            requests.exceptions.Timeout,
+        ],
+    )
+    def test_make_request_with_request_issue(self, make_request_mock: Mock, make_sig_mock: Mock):
+        """Checks that the underlying Timeout error propagates when the request reaches
+        the max retry count and the last attempt times out"""
+        with self.assertRaises(requests.exceptions.Timeout):
+            headers = {'Content-Type': 'application/json'}
+            request_body = ""
+            make_request(
+                "POST",
+                "altus_iam_gen_workload_auth_endpoint",
+                headers,
+                request_body,
+                TEST_AK,
+                TEST_PK,
+                False,
+                True,
+            )
+        make_request_mock.assert_called()
+        make_sig_mock.assert_called()
+
+
+class CDPAuthTokenV2TestCase(TestCase):
+    """Tests related to CDP auth token v2 acquisition"""
+
+    def test_get_auth_identifier(self):
+        """Identifier must be the access key"""
+        self.assertEqual(CDP_AUTH_AKV2_TEST.get_auth_identifier(), TEST_AK)
+
+    def test_get_auth_secret(self):
+        """Secret must be the private key"""
+        self.assertEqual(CDP_AUTH_AKV2_TEST.get_auth_secret(), TEST_PK)
+
+    @patch(
+        'airflow.providers.cloudera.security.cdp_security.make_request',
+        return_value=_make_response(200, {"service": {"environmentCrn": "my_env_crn"}}, ""),
+    )
+    def test_get_env_crn(self, make_request_mock: Mock):
+        """Correct environment shall be returned"""
+        env_crn = CDP_AUTH_AKV2_TEST.get_env_crn()
+        self.assertEqual("my_env_crn", env_crn)
+        make_request_mock.assert_called()
+
+    @patch(
+        'airflow.providers.cloudera.security.cdp_security.make_request',
+        side_effect=[requests.exceptions.Timeout, ServerError, ClientError],
+    )
+    def test_get_env_crn_with_issue_in_request(self, make_request_mock: Mock):
+        """Check error handling of various issues when trying to get CRN"""
+        for _ in range(iter_len_plus_one(make_request_mock.side_effect)):
+            with self.assertRaises(GetCrnError):
+                CDP_AUTH_AKV2_TEST.get_env_crn()
+        make_request_mock.assert_called()
+
+    @patch.object(CdpAccessKeyV2TokenAuth, 'get_env_crn', return_value="my_env_crn")
+    @patch(
+        'airflow.providers.cloudera.security.cdp_security.make_request',
+        return_value=VALID_CDP_TOKEN_AUTH_REQUEST_RESPONSE,
+    )
+    def test_generate_workload_auth_token(self, make_request_mock: Mock, env_crn_mock: Mock):
+        """Check that token is obtained in case of valid API response"""
+        cdp_token = CDP_AUTH_AKV2_TEST.generate_workload_auth_token("DE")
+        self.assertEqual("my_cdp_token", cdp_token.token)
+        env_crn_mock.assert_called()
+        self.assertEqual(make_request_mock.call_count, 1)
+
+    @patch.object(CdpAccessKeyV2TokenAuth, 'get_env_crn', return_value="my_env_crn")
+    @patch(
+        'airflow.providers.cloudera.security.cdp_security.make_request',
+        side_effect=[requests.exceptions.RequestException, ServerError, ClientError],
+    )
+    def test_generate_workload_auth_token_with_issue_in_request(
+        self, make_request_mock: Mock, env_crn_mock: Mock
+    ):
+        """Check error handling of various issues when trying to get CDP Auth token"""
+        for i in range(1, iter_len_plus_one(make_request_mock.side_effect), 1):
+            with self.assertRaises(CdpSecurityError) as err:
+                CDP_AUTH_AKV2_TEST.generate_workload_auth_token("DE")
+            self.assertEqual(env_crn_mock.call_count, i)
+            self.assertEqual(make_request_mock.call_count, i)
+            self.assertEqual(type(err.exception), CdpApiAError)
+
+
+class CDETestCase(TestCase):
+    """Tests for CDE Model related objects"""
+
+    def test_get_cache_key(self):
+        """Check that the cache key for a VC is ``<access_key>____<vc_host>``"""
+        self.assertEqual(CDE_AUTH_AKV2_TEST.get_cache_key(), TEST_CACHE_KEY)
+
+    def test_get_service_id_from_valid_url(self):
+        """Check that the service id can be properly extracted from a valid VC endpoint"""
+        valid_url = "https://k7s2ktbd.cde-5f95z6zc.dex-dev.xcu2-8y8x.dev.cldr.work/dex/api/v1"
+        self.assertEqual("cluster-5f95z6zc", VirtualCluster(valid_url).get_service_id())
+
+    def test_get_cluster_id_from_invalid_url(self):
+        """Check error handling when trying to extract the service id from invalid VC endpoints"""
+        invalid_urls = ["", "invalid_url"]
+        for url in invalid_urls:
+            with self.subTest(url):
+                with self.assertRaises(ValueError):
+                    VirtualCluster(url).get_service_id()
+    def test_get_auth_endpoint(self):
+        """Check that the auth endpoint can be obtained from the VC Endpoint"""
+        valid_endpoint = "https://k7s2ktbd.cde-5f95z6zc.dex-dev.xcu2-8y8x.dev.cldr.work/dex/api/v1"
+        expected = (
+            "https://service.cde-5f95z6zc.dex-dev.xcu2-8y8x.dev.cldr.work"
+            f"{VirtualCluster.ACCESS_KEY_AUTH_ENDPOINT_PATH}"
+        )
+        self.assertEqual(expected, VirtualCluster(valid_endpoint).get_auth_endpoint())
+
+    def test_get_auth_endpoint_invalid_inputs(self):
+        """Check error handling when trying to obtain auth endpoint from invalid VC endpoints"""
+        invalid_endpoints = [
+            "",
+            "invalid_url",
+            "http://xn-?-vbb/",
+            "http://k7s2ktbd.cde-5f95z6zc.dex-dev.xcu2-8y8x.dev.cldr.work/dex/api/v1",
+            "https://service.cde-5f95z6zc.dex-dev.xcu2-8y8x.dev.cldr.work/dex/api/v1",
+        ]
+        for url in invalid_endpoints:
+            with self.subTest(url):
+                with self.assertRaises(ValueError):
+                    VirtualCluster(url).get_auth_endpoint()
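# A standalone sketch (not part of the provider's code, shown only to make the
# two endpoint tests above concrete): the first DNS label of a virtual cluster
# host is swapped for "service" to reach the token endpoint host, and the
# service id is derived from the "cde-<id>" label. Helper names are hypothetical.
from urllib.parse import urlparse


def _vc_auth_host(vc_endpoint: str) -> str:
    # Derive the token-endpoint host from a virtual cluster endpoint URL.
    host = urlparse(vc_endpoint).hostname
    if not host or "." not in host:
        raise ValueError(f"not a virtual cluster endpoint: {vc_endpoint!r}")
    return "service." + host.split(".", 1)[1]


def _vc_service_id(vc_endpoint: str) -> str:
    # Derive the cluster-<id> service id from the cde-<id> DNS label.
    host = urlparse(vc_endpoint).hostname or ""
    for label in host.split("."):
        if label.startswith("cde-"):
            return "cluster-" + label[len("cde-"):]
    raise ValueError(f"no cde-<id> label in {vc_endpoint!r}")


assert _vc_auth_host("https://k7s2ktbd.cde-5f95z6zc.dex-dev.xcu2-8y8x.dev.cldr.work/dex/api/v1") == (
    "service.cde-5f95z6zc.dex-dev.xcu2-8y8x.dev.cldr.work"
)
assert _vc_service_id("https://k7s2ktbd.cde-5f95z6zc.dex-dev.xcu2-8y8x.dev.cldr.work/dex") == "cluster-5f95z6zc"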
+
+
+class CDEAuthTestCase(TestCase):
+    """Test cases related to CDE Auth based on CDP Auth TokenV2"""
+
+    @patch.object(
+        CdpAccessKeyV2TokenAuth, 'generate_workload_auth_token', return_value=VALID_CDP_TOKEN_AUTH_RESPONSE
+    )
+    @patch.object(requests, 'request', return_value=VALID_CDE_TOKEN_AUTH_REQUEST_RESPONSE)
+    def test_fetch_auth_token(self, get_mock, cdp_mock):
+        """Token can be acquired successfully in regular cases, with valid responses"""
+        cde_token = CDE_AUTH_AKV2_TEST.fetch_authentication_token()
+        self.assertEqual(cde_token.access_token, VALID_CDE_TOKEN)
+        get_mock.assert_called()
+        cdp_mock.assert_called()
+
+    @patch.object(
+        CdpAccessKeyV2TokenAuth, 'generate_workload_auth_token', return_value=VALID_CDP_TOKEN_AUTH_RESPONSE
+    )
+    @patch.object(requests, 'request', return_value=VALID_CDE_TOKEN_AUTH_REQUEST_RESPONSE)
+    def test_fetch_auth_token_insecure(self, get_mock: Mock, cdp_mock):
+        """Insecure mode (no certs check) for request is taken into account"""
+        cde_auth_akv2_test_insecure = CdeApiTokenAuth(
+            TEST_VC, CDP_AUTH_AKV2_TEST, TEST_CDE_AUTH_CACHE_STRATEGY, insecure=True
+        )
+        cde_token = cde_auth_akv2_test_insecure.fetch_authentication_token()
+        self.assertEqual(cde_token.access_token, VALID_CDE_TOKEN)
+        cdp_mock.assert_called()
+        get_mock.assert_called()
+        called_args = _get_call_arguments(get_mock.call_args)
+        self.assertEqual(called_args['verify'], False)
+
+    @patch.object(
+        CdpAccessKeyV2TokenAuth, 'generate_workload_auth_token', return_value=VALID_CDP_TOKEN_AUTH_RESPONSE
+    )
+    @patch.object(requests, 'request', return_value=VALID_CDE_TOKEN_AUTH_REQUEST_RESPONSE)
+    def test_fetch_auth_token_with_custom_ca(self, get_mock, cdp_mock):
+        """Check that a custom CA is used when specified"""
+        cde_auth_akv2_test_ca = CdeApiTokenAuth(
+            TEST_VC, CDP_AUTH_AKV2_TEST, TEST_CDE_AUTH_CACHE_STRATEGY, "ca"
+        )
+        cde_token = cde_auth_akv2_test_ca.fetch_authentication_token()
+        self.assertEqual(cde_token.access_token, VALID_CDE_TOKEN)
+        get_mock.assert_called()
+        cdp_mock.assert_called()
+        called_args = _get_call_arguments(get_mock.call_args)
+        self.assertEqual(called_args['verify'], "ca")
+
+    @patch.object(
+        CdpAccessKeyV2TokenAuth, 'generate_workload_auth_token', return_value=VALID_CDP_TOKEN_AUTH_RESPONSE
+    )
+    def test_fetch_auth_token_issue_in_request(self, cdp_mock):
+        """Check error handling when the request to Knox fails"""
+        with self.assertRaises(GetAuthTokenError) as err:
+            CDE_AUTH_AKV2_TEST.fetch_authentication_token()
+        cdp_mock.assert_called()
+        self.assertIsInstance(err.exception.raised_from, requests.RequestException)
+
+    @patch.object(
+        CdpAccessKeyV2TokenAuth, 'generate_workload_auth_token', return_value=VALID_CDP_TOKEN_AUTH_RESPONSE
+    )
+    @patch.object(requests, 'request', return_value=_make_response(401, {}, "Unauthorized"))
+    def test_fetch_auth_token_unauthorized(self, get_mock: Mock, cdp_mock: Mock):
+        """Check error handling when the request to Knox is not authorized"""
+        with self.assertRaises(GetAuthTokenError) as err:
+            CDE_AUTH_AKV2_TEST.fetch_authentication_token()
+        get_mock.assert_called()
+        self.assertIsInstance(err.exception.raised_from, ClientError)
+        cdp_mock.assert_called()
+
+    @patch.object(
+        CdpAccessKeyV2TokenAuth, 'generate_workload_auth_token', return_value=VALID_CDP_TOKEN_AUTH_RESPONSE
+    )
+    def test_fetch_auth_token_bad_vc_endpoint(self, cdp_mock):
+        """Check error handling when the VC endpoint is wrong"""
+        cde_auth_akv2_test_bad_vc = CdeApiTokenAuth(
+            VirtualCluster("bad"), CDP_AUTH_AKV2_TEST, TEST_CDE_AUTH_CACHE_STRATEGY
+        )
+        with self.assertRaises(GetAuthTokenError) as err:
+            cde_auth_akv2_test_bad_vc.fetch_authentication_token()
+        cdp_mock.assert_called()
+        self.assertIsInstance(err.exception.raised_from, ValueError)
+
+    def test_fetch_auth_token_with_cdp_issue(self):
+        """Check error handling when CDP related operations fail"""
+        with self.assertRaises(GetAuthTokenError) as err:
+            CDE_AUTH_AKV2_TEST.fetch_authentication_token()
+        self.assertIsInstance(err.exception.raised_from, GetCrnError)
+
+    @patch.object(
+        CdpAccessKeyV2TokenAuth, 'generate_workload_auth_token', return_value=VALID_CDP_TOKEN_AUTH_RESPONSE
+    )
+    @patch.object(requests, 'request', return_value=VALID_CDE_TOKEN_AUTH_REQUEST_RESPONSE)
+    def test_regular_auth_when_no_or_invalid_cache(self, get_mock, cdp_mock):
+        """Test that a fresh token is fetched when there is no valid cached token"""
+        cache_key = CDE_AUTH_AKV2_TEST.get_cache_key()
+        CDE_AUTH_AKV2_TEST.token_cache_strategy.clear_cached_auth_token(cache_key)
+        cde_token = CDE_AUTH_AKV2_TEST.get_cde_authentication_token()
+
+        self.assertEqual(cde_token.access_token, VALID_CDE_TOKEN)
+        get_mock.assert_called()
+        cdp_mock.assert_called()
+
+    @patch.object(
+        CdpAccessKeyV2TokenAuth, 'generate_workload_auth_token', return_value=VALID_CDP_TOKEN_AUTH_RESPONSE
+    )
+    def test_valid_cache_auth(self, cdp_mock):
+        """When the cache already exists and is valid"""
+        token_cache_path = (
+            EncryptedFileTokenCacheStrategy.CACHE_SUB_DIR + "/" + CDE_AUTH_AKV2_TEST.get_cache_key()
+        )
+        with open(token_cache_path, 'w') as cache_file:
+            cache_file.write(
+                TEST_ENC_KEY.encrypt(dumps(VALID_CDE_TOKEN_RESPONSE_BODY).encode('utf-8')).decode('utf-8')
+            )
+        cde_token = CDE_AUTH_AKV2_TEST.get_cde_authentication_token()
+
+        self.assertEqual(cde_token.access_token, VALID_CDE_TOKEN)
+        self.assertEqual(cdp_mock.call_count, 0)
+
+        os.remove(token_cache_path)
+
+
+class TokenCacheTestCase(TestCase):
+    """Tests related to generic operations for tokens / caches"""
+
+    def test_filetoken_cache_no_dir(self):
+        """Error handling when invalid cache dirs are provided"""
+        with self.assertRaises(CacheError):
+            EncryptedFileTokenCacheStrategy(CdeTokenAuthResponse, TEST_PK, cache_dir=None)
+
+        with self.assertRaises(CacheError):
+            EncryptedFileTokenCacheStrategy(CdeTokenAuthResponse, TEST_PK, cache_dir=" ")
+
+    def test_get_cache(self):
+        """Check that the cache can be retrieved properly and that decrypting the content
+        works as expected"""
+        with open(TEST_CACHE_KEY_PATH, 'w') as cache_file:
+            content = dumps(VALID_CDE_TOKEN_RESPONSE_BODY).encode('utf-8')
+            dump(TEST_ENC_KEY.encrypt(content).decode('utf-8'), cache_file)
+
+        expected_token = VALID_CDE_TOKEN_AUTH_RESPONSE
+        actual_token = TEST_CDE_AUTH_CACHE_STRATEGY.get_cached_auth_token(TEST_CACHE_KEY)
+        self.assertEqual(expected_token, actual_token)
+        self.assertNotEqual(VALID_CDE_TOKEN_RESPONSE_BODY, actual_token)
+
+        os.remove(TEST_CACHE_KEY_PATH)
+
+    def test_cache_token(self):
+        """Test that a token can be cached and written properly, in an encrypted manner"""
+        cde_token = VALID_CDE_TOKEN_AUTH_RESPONSE
+        ftcs = TEST_CDE_AUTH_CACHE_STRATEGY
+        ftcs.cache_auth_token(TEST_CACHE_KEY, cde_token)
+        actual_token = ftcs.get_cached_auth_token(TEST_CACHE_KEY)
+        self.assertEqual(cde_token, actual_token)
+
+        os.remove(TEST_CACHE_KEY_PATH)
+
+    def test_cannot_get_cache(self):
+        """Error handling in various situations when trying to obtain the cache"""
+        with open(TEST_CACHE_KEY_PATH, 'w') as cache_file:
+            cache_file.write('wrong')
+        with self.assertRaises(CacheError) as err:
+            TEST_CDE_AUTH_CACHE_STRATEGY.get_cached_auth_token(TEST_CACHE_KEY)
+        self.assertIsInstance(err.exception.raised_from, JSONDecodeError)
+
+        with open(TEST_CACHE_KEY_PATH, 'w') as cache_file:
+            cache_file.write('{"xyz": "my_cde_token", "expires_in": 123}')
+        with self.assertRaises(CacheError) as err:
+            TEST_CDE_AUTH_CACHE_STRATEGY.get_cached_auth_token(TEST_CACHE_KEY)
+        self.assertIsInstance(err.exception.raised_from, TypeError)
+
+        os.remove(TEST_CACHE_KEY_PATH)
+        with self.assertRaises(CacheError) as err:
+            TEST_CDE_AUTH_CACHE_STRATEGY.get_cached_auth_token(TEST_CACHE_KEY)
+        self.assertIsInstance(err.exception.raised_from, FileNotFoundError)
+
+    def test_cannot_cache_token(self):
+        """Error handling in various situations when trying to write the cache"""
+        cde_token = VALID_CDE_TOKEN_AUTH_RESPONSE
+        with open(TEST_CACHE_KEY_PATH, 'w') as cache_file:
+            cache_file.write('wrong')
+        os.chmod(TEST_CACHE_KEY_PATH, 0o400)
+        ftcs = TEST_CDE_AUTH_CACHE_STRATEGY
+        with self.assertRaises(CacheError) as err:
+            ftcs.cache_auth_token(TEST_CACHE_KEY, cde_token)
+        self.assertIsInstance(err.exception.raised_from, PermissionError)
+
+        os.remove(TEST_CACHE_KEY_PATH)
+
+
+if __name__ == '__main__':
+    main()
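A minimal sketch of the encrypt-then-cache round trip these token-cache tests verify, using cryptography's Fernet directly. The key derivation below is a placeholder; the provider derives keys via ``TokenCacheStrategy.get_fernet_encryption_key``, whose exact scheme is not shown in this diff:

import base64
import hashlib
import json

from cryptography.fernet import Fernet

private_key = "private_key_xxxxx_xxxxx_xxxxx_xx"
# Hypothetical derivation: any urlsafe-base64-encoded 32-byte digest of the
# secret yields a valid Fernet key; the provider's actual scheme may differ.
key = base64.urlsafe_b64encode(hashlib.sha256(private_key.encode()).digest())
fernet = Fernet(key)

token = {"access_token": "my_cde_token", "expires_in": 123}
ciphertext = fernet.encrypt(json.dumps(token).encode("utf-8"))  # what gets written to the cache file
assert json.loads(fernet.decrypt(ciphertext)) == token  # and what reading the cache recovers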
diff --git a/tests/providers/cloudera/sensors/__init__.py b/tests/providers/cloudera/sensors/__init__.py
new file mode 100644
index 0000000000000..13a83393a9124
--- /dev/null
+++ b/tests/providers/cloudera/sensors/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/providers/cloudera/sensors/test_metastore_sensor.py b/tests/providers/cloudera/sensors/test_metastore_sensor.py
new file mode 100644
index 0000000000000..9a405808e1640
--- /dev/null
+++ b/tests/providers/cloudera/sensors/test_metastore_sensor.py
@@ -0,0 +1,64 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Cloudera Airflow Provider
+# (C) Cloudera, Inc. 2021-2022
+# All rights reserved.
+# Applicable Open Source License: Apache License Version 2.0
+#
+# NOTE: Cloudera open source products are modular software products
+# made up of hundreds of individual components, each of which was
+# individually copyrighted. Each Cloudera open source product is a
+# collective work under U.S. Copyright Law. Your license to use the
+# collective work is as provided in your written agreement with
+# Cloudera. Used apart from the collective work, this file is
+# licensed for your use pursuant to the open source license
+# identified above.
+#
+# This code is provided to you pursuant a written agreement with
+# (i) Cloudera, Inc. or (ii) a third-party authorized to distribute
+# this code. If you do not have a written agreement with Cloudera nor
+# with an authorized and properly licensed third party, you do not
+# have any rights to access nor to use this code.
+#
+# Absent a written agreement with Cloudera, Inc. (“Cloudera”) to the
+# contrary, A) CLOUDERA PROVIDES THIS CODE TO YOU WITHOUT WARRANTIES OF ANY
+# KIND; (B) CLOUDERA DISCLAIMS ANY AND ALL EXPRESS AND IMPLIED
+# WARRANTIES WITH RESPECT TO THIS CODE, INCLUDING BUT NOT LIMITED TO
+# IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY AND
+# FITNESS FOR A PARTICULAR PURPOSE; (C) CLOUDERA IS NOT LIABLE TO YOU,
+# AND WILL NOT DEFEND, INDEMNIFY, NOR HOLD YOU HARMLESS FOR ANY CLAIMS
+# ARISING FROM OR RELATED TO THE CODE; AND (D) WITH RESPECT TO YOUR EXERCISE
+# OF ANY RIGHTS GRANTED TO YOU FOR THE CODE, CLOUDERA IS NOT LIABLE FOR ANY
+# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR
+# CONSEQUENTIAL DAMAGES INCLUDING, BUT NOT LIMITED TO, DAMAGES
+# RELATED TO LOST REVENUE, LOST PROFITS, LOSS OF INCOME, LOSS OF
+# BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
+# DATA.
+
+from airflow.providers.cloudera.hooks.cdw_hook import CdwHiveMetastoreHook
+
+
+def test_csv_parse():
+    """
+    This is just a simple validation test for the csv reader. The variable beeline_output
+    contains a sample response that comes from Hive when --outputformat=csv2 is used.
+    """
+    beeline_output = "db_name,tbl_name,part_name\n" "default,test_part,dt=1"
+    result_list = CdwHiveMetastoreHook.parse_csv_lines(beeline_output)
+
+    assert len(result_list) == 2, result_list
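Beeline's ``--outputformat=csv2`` emits plain CSV, so a functionally equivalent parser can be sketched with the standard library. This only illustrates what the test above expects; the real ``CdwHiveMetastoreHook.parse_csv_lines`` is not part of this diff:

import csv


def parse_csv_lines(output):
    """Parse beeline --outputformat=csv2 output into a list of rows."""
    return list(csv.reader(output.splitlines()))


rows = parse_csv_lines("db_name,tbl_name,part_name\ndefault,test_part,dt=1")
assert rows == [['db_name', 'tbl_name', 'part_name'], ['default', 'test_part', 'dt=1']]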
diff --git a/tests/providers/cloudera/utils.py b/tests/providers/cloudera/utils.py
new file mode 100644
index 0000000000000..9f82f8adb7fc7
--- /dev/null
+++ b/tests/providers/cloudera/utils.py
@@ -0,0 +1,47 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Utils module for common utility methods used in the tests"""
+from itertools import tee
+from json import dumps
+
+from requests import Response
+
+
+def iter_len_plus_one(iterator):
+    """Return the length + 1 of the given iterator.
+    The +1 is because in the tests the first side effect is already consumed"""
+    return sum(1 for _ in tee(iterator)) + 1
+
+
+def _get_call_arguments(call_args):
+    """Return the kwargs of a mock's ``call_args``, which unpacks as either
+    (args, kwargs) or (name, args, kwargs) depending on the Python version."""
+    if len(call_args) == 2:
+        # returned tuple is args, kwargs = call_args
+        _, kwargs = call_args
+    else:
+        # returned tuple is name, args, kwargs = call_args
+        _, _, kwargs = call_args
+
+    return kwargs
+
+
+def _make_response(status, body, reason):
+    """Build a stub ``requests.Response`` with the given status, JSON body and reason."""
+    resp = Response()
+    resp.status_code = status
+    # _content is the raw payload attribute that Response.json() and .text read
+    resp._content = dumps(body).encode('utf-8')
+    resp.reason = reason
+    return resp