diff --git a/pyhive/hive.py b/pyhive/hive.py
index a8635bac..66569406 100644
--- a/pyhive/hive.py
+++ b/pyhive/hive.py
@@ -8,9 +8,12 @@
 from __future__ import absolute_import
 from __future__ import unicode_literals
 
+import base64
 import datetime
 import re
 from decimal import Decimal
+from ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED, create_default_context
+
 
 from TCLIService import TCLIService
 from TCLIService import constants
@@ -25,6 +28,7 @@
 import getpass
 import logging
 import sys
+import thrift.transport.THttpClient
 import thrift.protocol.TBinaryProtocol
 import thrift.transport.TSocket
 import thrift.transport.TTransport
@@ -38,6 +42,12 @@
 
 _TIMESTAMP_PATTERN = re.compile(r'(\d+-\d+-\d+ \d+:\d+:\d+(\.\d{,6})?)')
 
+ssl_cert_parameter_map = {
+    "none": CERT_NONE,
+    "optional": CERT_OPTIONAL,
+    "required": CERT_REQUIRED,
+}
+
 
 def _parse_timestamp(value):
     if value:
@@ -97,9 +107,21 @@ def connect(*args, **kwargs):
 class Connection(object):
     """Wraps a Thrift session"""
 
-    def __init__(self, host=None, port=None, username=None, database='default', auth=None,
-                 configuration=None, kerberos_service_name=None, password=None,
-                 thrift_transport=None):
+    def __init__(
+        self,
+        host=None,
+        port=None,
+        scheme=None,
+        username=None,
+        database='default',
+        auth=None,
+        configuration=None,
+        kerberos_service_name=None,
+        password=None,
+        check_hostname=None,
+        ssl_cert=None,
+        thrift_transport=None
+    ):
         """Connect to HiveServer2
 
         :param host: What host HiveServer2 runs on
@@ -116,6 +138,48 @@ def __init__(self, host=None, port=None, username=None, database='default', auth
             https://github.com/cloudera/impyla/blob/255b07ed973d47a3395214ed92d35ec0615ebf62
             /impala/_thrift_api.py#L152-L160
         """
+        if scheme in ("https", "http") and thrift_transport is None:
+            # HTTP(S) transport mode: build the Thrift HTTP client here and
+            # use it as the ``thrift_transport``.
+            ssl_context = None
+            if scheme == "https":
+                ssl_context = create_default_context()
+                # Accept a real boolean as well as the "true" string that
+                # arrives when the value comes from URL query parameters.
+                ssl_context.check_hostname = check_hostname in (True, "true")
+                ssl_cert = ssl_cert or "none"
+                if ssl_cert not in ssl_cert_parameter_map:
+                    # Fail loudly: silently falling back to CERT_NONE on a
+                    # typo would disable certificate verification.
+                    raise ValueError(
+                        "ssl_cert is not valid use one of: "
+                        "none, optional, required"
+                    )
+                ssl_context.verify_mode = ssl_cert_parameter_map[ssl_cert]
+            thrift_transport = thrift.transport.THttpClient.THttpClient(
+                uri_or_host="{scheme}://{host}:{port}/cliservice/".format(
+                    scheme=scheme, host=host, port=port
+                ),
+                ssl_context=ssl_context,
+            )
+
+            if auth in ("BASIC", "NOSASL", "NONE", None):
+                # Always needs the Authorization header
+                self._set_authorization_header(thrift_transport, username, password)
+            elif auth == "KERBEROS" and kerberos_service_name:
+                self._set_kerberos_header(thrift_transport, kerberos_service_name, host)
+            else:
+                raise ValueError(
+                    "Authentication is not valid use one of: "
+                    "BASIC, NOSASL, KERBEROS, NONE"
+                )
+            # NOTE(review): these are cleared presumably so the existing
+            # thrift_transport handling below (outside this hunk) sees no
+            # conflicting connection arguments -- confirm.
+            host, port, auth, kerberos_service_name, password = (
+                None, None, None, None, None
+            )
+
         username = username or getpass.getuser()
         configuration = configuration or {}
 
@@ -207,6 +271,48 @@ def sasl_factory():
             self._transport.close()
             raise
 
+    @staticmethod
+    def _set_authorization_header(transport, username=None, password=None):
+        """Attach an HTTP Basic ``Authorization`` header to ``transport``.
+
+        Missing credentials fall back to the placeholders "user"/"pass" so a
+        header is always sent.
+        """
+        username = username or "user"
+        password = password or "pass"
+        auth_credentials = "{}:{}".format(username, password).encode("UTF-8")
+        auth_credentials_base64 = base64.standard_b64encode(auth_credentials).decode(
+            "UTF-8"
+        )
+        transport.setCustomHeaders(
+            {"Authorization": "Basic {}".format(auth_credentials_base64)}
+        )
+
+    @staticmethod
+    def _set_kerberos_header(transport, kerberos_service_name, host):
+        """Attach a Kerberos ``Authorization: Negotiate`` header to ``transport``.
+
+        :param kerberos_service_name: Kerberos service name; combined with
+            ``host`` as ``<service>@<host>`` to name the target service
+        """
+        # Imported here rather than at module level, so the kerberos package
+        # is only required when this method runs.
+        import kerberos
+
+        __, krb_context = kerberos.authGSSClientInit(
+            service="{}@{}".format(kerberos_service_name, host)
+        )
+        try:
+            kerberos.authGSSClientStep(krb_context, "")
+            auth_header = kerberos.authGSSClientResponse(krb_context)
+        finally:
+            # authGSSClientClean() takes only the context and releases it, so
+            # it must run after the token has been read, not before the step.
+            kerberos.authGSSClientClean(krb_context)
+
+        transport.setCustomHeaders(
+            {"Authorization": "Negotiate {}".format(auth_header)}
+        )
+
     def __enter__(self):
         """Transport should already be opened by __init__"""
         return self
diff --git a/pyhive/sqlalchemy_hive.py b/pyhive/sqlalchemy_hive.py
index 59e0c0ee..2ef49652 100644
--- a/pyhive/sqlalchemy_hive.py
+++ b/pyhive/sqlalchemy_hive.py
@@ -374,3 +374,35 @@ def _check_unicode_returns(self, connection, additional_tests=None):
     def _check_unicode_description(self, connection):
         # We decode everything as UTF-8
         return True
+
+
+class HiveHTTPDialect(HiveDialect):
+    """Hive dialect variant that passes ``scheme="http"`` to the connection."""
+
+    name = "hive"
+    scheme = "http"
+    driver = "rest"
+
+    def create_connect_args(self, url):
+        """Build the ``(args, kwargs)`` connect arguments from a SQLAlchemy URL.
+
+        URL query parameters are merged last and therefore override the
+        values derived from the URL itself.
+        """
+        kwargs = {
+            "host": url.host,
+            "port": url.port or 10000,
+            "scheme": self.scheme,
+            "username": url.username or None,
+            "password": url.password or None,
+        }
+        if url.query:
+            kwargs.update(url.query)
+        return [], kwargs
+
+
+class HiveHTTPSDialect(HiveHTTPDialect):
+    """Same as ``HiveHTTPDialect`` but passes ``scheme="https"``."""
+
+    name = "hive"
+    scheme = "https"
diff --git a/setup.py b/setup.py
index df410dbc..ad34a38b 100755
--- a/setup.py
+++ b/setup.py
@@ -66,6 +66,8 @@ def run_tests(self):
     entry_points={
         'sqlalchemy.dialects': [
             'hive = pyhive.sqlalchemy_hive:HiveDialect',
+            'hive.http = pyhive.sqlalchemy_hive:HiveHTTPDialect',
+            'hive.https = pyhive.sqlalchemy_hive:HiveHTTPSDialect',
             'presto = pyhive.sqlalchemy_presto:PrestoDialect',
             'trino = pyhive.sqlalchemy_trino:TrinoDialect',
         ],