diff --git a/airflow/providers/vertica/hooks/vertica.py b/airflow/providers/vertica/hooks/vertica.py index 92a74ea36901c..06b2e3cf179b7 100644 --- a/airflow/providers/vertica/hooks/vertica.py +++ b/airflow/providers/vertica/hooks/vertica.py @@ -46,5 +46,56 @@ def get_conn(self) -> connect: else: conn_config["port"] = int(conn.port) + bool_options = [ + "connection_load_balance", + "binary_transfer", + "disable_copy_local", + "request_complex_types", + "use_prepared_statements", + ] + std_options = [ + "session_label", + "backup_server_node", + "kerberos_host_name", + "kerberos_service_name", + "unicode_error", + "workload", + "ssl", + ] + conn_extra = conn.extra_dejson + + for bo in bool_options: + if bo in conn_extra: + conn_config[bo] = str(conn_extra[bo]).lower() in ["true", "on"] + + for so in std_options: + if so in conn_extra: + conn_config[so] = conn_extra[so] + + if "connection_timeout" in conn_extra: + conn_config["connection_timeout"] = float(conn_extra["connection_timeout"]) + + if "log_level" in conn_extra: + import logging + + log_lvl = conn_extra["log_level"] + conn_config["log_path"] = None + if isinstance(log_lvl, str): + log_lvl = log_lvl.lower() + if log_lvl == "critical": + conn_config["log_level"] = logging.CRITICAL + elif log_lvl == "error": + conn_config["log_level"] = logging.ERROR + elif log_lvl == "warning": + conn_config["log_level"] = logging.WARNING + elif log_lvl == "info": + conn_config["log_level"] = logging.INFO + elif log_lvl == "debug": + conn_config["log_level"] = logging.DEBUG + elif log_lvl == "notset": + conn_config["log_level"] = logging.NOTSET + else: + conn_config["log_level"] = int(conn_extra["log_level"]) + conn = connect(**conn_config) return conn diff --git a/docs/apache-airflow-providers-vertica/connections/vertica.rst b/docs/apache-airflow-providers-vertica/connections/vertica.rst new file mode 100644 index 0000000000000..86f583a548124 --- /dev/null +++ b/docs/apache-airflow-providers-vertica/connections/vertica.rst @@ -0,0 +1,83 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + + +.. _howto/connection:vertica: + +Vertica Connection +================== +The Vertica connection type provides connection to a Vertica database. + +Configuring the Connection +-------------------------- +Host (required) + The host to connect to. + +Schema (optional) + Specify the schema name to be used in the database. + +Login (required) + Specify the user name to connect. + +Password (required) + Specify the password to connect. + +Extra (optional) + Specify the extra parameters (as json dictionary) that can be used in Vertica + connection. + + The following extras are supported: + + * ``backup_server_node``: See `Connection Failover `_. + * ``binary_transfer``: See `Data Transfer Format `_. + * ``connection_load_balance``: See `Connection Load Balancing `_. + * ``connection_timeout``: The number of seconds (can be a nonnegative floating point number) the client + waits for a socket operation (Establishing a TCP connection or read/write operation). + * ``disable_copy_local``: See `COPY FROM LOCAL `_. + * ``kerberos_host_name``: See `Kerberos Authentication `_. + * ``kerberos_service_name``: See `Kerberos Authentication `_. + * ``log_level``: Enable vertica client logging. Traces will be visible in tasks log. See `Logging `_. + * ``request_complex_types:``: See `SQL Data conversion to Python objects `_. + * ``session_label``: Sets a label for the connection on the server. + * ``ssl``: Support only True or False. See `TLS/SSL `_. + * ``unicode_error``: See `UTF-8 encoding issues `_. + * ``use_prepared_statements``: See `Passing parameters to SQL queries `_. + * ``workload``: Sets the workload name associated with this session. + + See `vertica-python docs `_ for details. + + + Example "extras" field: + + .. code-block:: json + + { + "connection_load_balance": true, + "log_level": "error", + "ssl": true + } + + or + + .. code-block:: json + + { + "session_label": "airflow-session", + "connection_timeout": 30, + "backup_server_node": ["bck_server_1", "bck_server_2"] + } diff --git a/docs/apache-airflow-providers-vertica/index.rst b/docs/apache-airflow-providers-vertica/index.rst index db09f1924c4d7..ae7c2457b8117 100644 --- a/docs/apache-airflow-providers-vertica/index.rst +++ b/docs/apache-airflow-providers-vertica/index.rst @@ -29,6 +29,13 @@ Changelog Security +.. toctree:: + :hidden: + :maxdepth: 1 + :caption: Guides + + Connection types + .. toctree:: :hidden: :maxdepth: 1 diff --git a/tests/providers/vertica/hooks/test_vertica.py b/tests/providers/vertica/hooks/test_vertica.py index e78c2a0c5c813..146c3bcd1136b 100644 --- a/tests/providers/vertica/hooks/test_vertica.py +++ b/tests/providers/vertica/hooks/test_vertica.py @@ -17,6 +17,7 @@ # under the License. from __future__ import annotations +import json from unittest import mock from unittest.mock import patch @@ -47,6 +48,81 @@ def test_get_conn(self, mock_connect): host="host", port=5433, database="vertica", user="login", password="password" ) + @patch("airflow.providers.vertica.hooks.vertica.connect") + def test_get_conn_extra_parameters_no_cast(self, mock_connect): + """Test if parameters are correctly passed to connection""" + extra_dict = self.connection.extra_dejson + bool_options = [ + "connection_load_balance", + "binary_transfer", + "disable_copy_local", + "use_prepared_statements", + ] + for bo in bool_options: + extra_dict.update({bo: True}) + extra_dict.update({"request_complex_types": False}) + + std_options = [ + "session_label", + "kerberos_host_name", + "kerberos_service_name", + "unicode_error", + "workload", + "ssl", + ] + for so in std_options: + extra_dict.update({so: so}) + bck_server_node = ["1.2.3.4", "4.3.2.1"] + conn_timeout = 30 + log_lvl = 40 + extra_dict.update({"backup_server_node": bck_server_node}) + extra_dict.update({"connection_timeout": conn_timeout}) + extra_dict.update({"log_level": log_lvl}) + self.connection.extra = json.dumps(extra_dict) + self.db_hook.get_conn() + assert mock_connect.call_count == 1 + args, kwargs = mock_connect.call_args + assert args == () + for bo in bool_options: + assert kwargs[bo] is True + assert kwargs["request_complex_types"] is False + for so in std_options: + assert kwargs[so] == so + assert bck_server_node[0] in kwargs["backup_server_node"] + assert bck_server_node[1] in kwargs["backup_server_node"] + assert kwargs["connection_timeout"] == conn_timeout + assert kwargs["log_level"] == log_lvl + assert kwargs["log_path"] is None + + @patch("airflow.providers.vertica.hooks.vertica.connect") + def test_get_conn_extra_parameters_cast(self, mock_connect): + """Test if parameters that can be passed either as string or int/bool + like log_level are correctly converted when passed as string + (while test_get_conn_extra_parameters_no_cast tests them passed as int/bool)""" + import logging + + extra_dict = self.connection.extra_dejson + bool_options = [ + "connection_load_balance", + "binary_transfer", + "disable_copy_local", + "use_prepared_statements", + ] + for bo in bool_options: + extra_dict.update({bo: "True"}) + extra_dict.update({"request_complex_types": "False"}) + extra_dict.update({"log_level": "Error"}) + self.connection.extra = json.dumps(extra_dict) + self.db_hook.get_conn() + assert mock_connect.call_count == 1 + args, kwargs = mock_connect.call_args + assert args == () + for bo in bool_options: + assert kwargs[bo] is True + assert kwargs["request_complex_types"] is False + assert kwargs["log_level"] == logging.ERROR + assert kwargs["log_path"] is None + class TestVerticaHook: def setup_method(self):