diff --git a/README.rst b/README.rst index 89c54532..5afd746c 100644 --- a/README.rst +++ b/README.rst @@ -71,9 +71,11 @@ First install this package to register it with SQLAlchemy (see ``setup.py``). # Presto engine = create_engine('presto://localhost:8080/hive/default') # Trino - engine = create_engine('trino://localhost:8080/hive/default') + engine = create_engine('trino+pyhive://localhost:8080/hive/default') # Hive engine = create_engine('hive://localhost:10000/default') + + # SQLAlchemy < 2.0 logs = Table('my_awesome_data', MetaData(bind=engine), autoload=True) print select([func.count('*')], from_obj=logs).scalar() @@ -82,6 +84,20 @@ First install this package to register it with SQLAlchemy (see ``setup.py``). logs = Table('my_awesome_data', MetaData(bind=engine), autoload=True) print select([func.count('*')], from_obj=logs).scalar() + # SQLAlchemy >= 2.0 + metadata_obj = MetaData() + books = Table("books", metadata_obj, Column("id", Integer), Column("title", String), Column("primary_author", String)) + metadata_obj.create_all(engine) + inspector = inspect(engine) + inspector.get_columns('books') + + with engine.connect() as con: + data = [{ "id": 1, "title": "The Hobbit", "primary_author": "Tolkien" }, + { "id": 2, "title": "The Silmarillion", "primary_author": "Tolkien" }] + con.execute(books.insert(), data[0]) + result = con.execute(text("select * from books")) + print(result.fetchall()) + Note: query generation functionality is not exhaustive or fully tested, but there should be no problem with raw SQL. diff --git a/pyhive/sqlalchemy_hive.py b/pyhive/sqlalchemy_hive.py index f39f1793..e2244525 100644 --- a/pyhive/sqlalchemy_hive.py +++ b/pyhive/sqlalchemy_hive.py @@ -13,11 +13,22 @@ import re from sqlalchemy import exc -from sqlalchemy import processors +from sqlalchemy.sql import text +try: + from sqlalchemy import processors +except ImportError: + # Required for SQLAlchemy>=2.0 + from sqlalchemy.engine import processors from sqlalchemy import types from sqlalchemy import util # TODO shouldn't use mysql type -from sqlalchemy.databases import mysql +try: + from sqlalchemy.databases import mysql + mysql_tinyinteger = mysql.MSTinyInteger +except ImportError: + # Required for SQLAlchemy>2.0 + from sqlalchemy.dialects import mysql + mysql_tinyinteger = mysql.base.MSTinyInteger from sqlalchemy.engine import default from sqlalchemy.sql import compiler from sqlalchemy.sql.compiler import SQLCompiler @@ -121,7 +132,7 @@ def __init__(self, dialect): _type_map = { 'boolean': types.Boolean, - 'tinyint': mysql.MSTinyInteger, + 'tinyint': mysql_tinyinteger, 'smallint': types.SmallInteger, 'int': types.Integer, 'bigint': types.BigInteger, @@ -247,10 +258,15 @@ class HiveDialect(default.DefaultDialect): supports_multivalues_insert = True type_compiler = HiveTypeCompiler supports_sane_rowcount = False + supports_statement_cache = False @classmethod def dbapi(cls): return hive + + @classmethod + def import_dbapi(cls): + return hive def create_connect_args(self, url): kwargs = { @@ -265,7 +281,7 @@ def create_connect_args(self, url): def get_schema_names(self, connection, **kw): # Equivalent to SHOW DATABASES - return [row[0] for row in connection.execute('SHOW SCHEMAS')] + return [row[0] for row in connection.execute(text('SHOW SCHEMAS'))] def get_view_names(self, connection, schema=None, **kw): # Hive does not provide functionality to query tableType @@ -280,7 +296,7 @@ def _get_table_columns(self, connection, table_name, schema): # Using DESCRIBE works but is uglier. try: # This needs the table name to be unescaped (no backticks). - rows = connection.execute('DESCRIBE {}'.format(full_table)).fetchall() + rows = connection.execute(text('DESCRIBE {}'.format(full_table))).fetchall() except exc.OperationalError as e: # Does the table exist? regex_fmt = r'TExecuteStatementResp.*SemanticException.*Table not found {}' @@ -296,7 +312,7 @@ def _get_table_columns(self, connection, table_name, schema): raise exc.NoSuchTableError(full_table) return rows - def has_table(self, connection, table_name, schema=None): + def has_table(self, connection, table_name, schema=None, **kw): try: self._get_table_columns(connection, table_name, schema) return True @@ -361,7 +377,7 @@ def get_table_names(self, connection, schema=None, **kw): query = 'SHOW TABLES' if schema: query += ' IN ' + self.identifier_preparer.quote_identifier(schema) - return [row[0] for row in connection.execute(query)] + return [row[0] for row in connection.execute(text(query))] def do_rollback(self, dbapi_connection): # No transactions for Hive diff --git a/pyhive/sqlalchemy_presto.py b/pyhive/sqlalchemy_presto.py index a199ebe1..bfe1ba04 100644 --- a/pyhive/sqlalchemy_presto.py +++ b/pyhive/sqlalchemy_presto.py @@ -9,11 +9,19 @@ from __future__ import unicode_literals import re +import sqlalchemy from sqlalchemy import exc from sqlalchemy import types from sqlalchemy import util # TODO shouldn't use mysql type -from sqlalchemy.databases import mysql +from sqlalchemy.sql import text +try: + from sqlalchemy.databases import mysql + mysql_tinyinteger = mysql.MSTinyInteger +except ImportError: + # Required for SQLAlchemy>=2.0 + from sqlalchemy.dialects import mysql + mysql_tinyinteger = mysql.base.MSTinyInteger from sqlalchemy.engine import default from sqlalchemy.sql import compiler from sqlalchemy.sql.compiler import SQLCompiler @@ -21,6 +29,7 @@ from pyhive import presto from pyhive.common import UniversalSet +sqlalchemy_version = float(re.search(r"^([\d]+\.[\d]+)\..+", sqlalchemy.__version__).group(1)) class PrestoIdentifierPreparer(compiler.IdentifierPreparer): # Just quote everything to make things simpler / easier to upgrade @@ -29,7 +38,7 @@ class PrestoIdentifierPreparer(compiler.IdentifierPreparer): _type_map = { 'boolean': types.Boolean, - 'tinyint': mysql.MSTinyInteger, + 'tinyint': mysql_tinyinteger, 'smallint': types.SmallInteger, 'integer': types.Integer, 'bigint': types.BigInteger, @@ -80,6 +89,7 @@ class PrestoDialect(default.DefaultDialect): supports_multivalues_insert = True supports_unicode_statements = True supports_unicode_binds = True + supports_statement_cache = False returns_unicode_strings = True description_encoding = None supports_native_boolean = True @@ -88,6 +98,10 @@ class PrestoDialect(default.DefaultDialect): @classmethod def dbapi(cls): return presto + + @classmethod + def import_dbapi(cls): + return presto def create_connect_args(self, url): db_parts = (url.database or 'hive').split('/') @@ -108,14 +122,14 @@ def create_connect_args(self, url): return [], kwargs def get_schema_names(self, connection, **kw): - return [row.Schema for row in connection.execute('SHOW SCHEMAS')] + return [row.Schema for row in connection.execute(text('SHOW SCHEMAS'))] def _get_table_columns(self, connection, table_name, schema): full_table = self.identifier_preparer.quote_identifier(table_name) if schema: full_table = self.identifier_preparer.quote_identifier(schema) + '.' + full_table try: - return connection.execute('SHOW COLUMNS FROM {}'.format(full_table)) + return connection.execute(text('SHOW COLUMNS FROM {}'.format(full_table))) except (presto.DatabaseError, exc.DatabaseError) as e: # Normally SQLAlchemy should wrap this exception in sqlalchemy.exc.DatabaseError, which # it successfully does in the Hive version. The difference with Presto is that this @@ -134,7 +148,7 @@ def _get_table_columns(self, connection, table_name, schema): else: raise - def has_table(self, connection, table_name, schema=None): + def has_table(self, connection, table_name, schema=None, **kw): try: self._get_table_columns(connection, table_name, schema) return True @@ -176,6 +190,8 @@ def get_indexes(self, connection, table_name, schema=None, **kw): # - a boolean column named "Partition Key" # - a string in the "Comment" column # - a string in the "Extra" column + if sqlalchemy_version >= 1.4: + row = row._mapping is_partition_key = ( (part_key in row and row[part_key]) or row['Comment'].startswith(part_key) @@ -192,7 +208,7 @@ def get_table_names(self, connection, schema=None, **kw): query = 'SHOW TABLES' if schema: query += ' FROM ' + self.identifier_preparer.quote_identifier(schema) - return [row.Table for row in connection.execute(query)] + return [row.Table for row in connection.execute(text(query))] def do_rollback(self, dbapi_connection): # No transactions for Presto diff --git a/pyhive/sqlalchemy_trino.py b/pyhive/sqlalchemy_trino.py index 4b2b3698..11be2a6c 100644 --- a/pyhive/sqlalchemy_trino.py +++ b/pyhive/sqlalchemy_trino.py @@ -13,7 +13,13 @@ from sqlalchemy import types from sqlalchemy import util # TODO shouldn't use mysql type -from sqlalchemy.databases import mysql +try: + from sqlalchemy.databases import mysql + mysql_tinyinteger = mysql.MSTinyInteger +except ImportError: + # Required for SQLAlchemy>=2.0 + from sqlalchemy.dialects import mysql + mysql_tinyinteger = mysql.base.MSTinyInteger from sqlalchemy.engine import default from sqlalchemy.sql import compiler from sqlalchemy.sql.compiler import SQLCompiler @@ -28,7 +34,7 @@ class TrinoIdentifierPreparer(PrestoIdentifierPreparer): _type_map = { 'boolean': types.Boolean, - 'tinyint': mysql.MSTinyInteger, + 'tinyint': mysql_tinyinteger, 'smallint': types.SmallInteger, 'integer': types.Integer, 'bigint': types.BigInteger, @@ -67,7 +73,12 @@ def visit_TEXT(self, type_, **kw): class TrinoDialect(PrestoDialect): name = 'trino' + supports_statement_cache = False @classmethod def dbapi(cls): return trino + + @classmethod + def import_dbapi(cls): + return trino diff --git a/pyhive/tests/sqlalchemy_test_case.py b/pyhive/tests/sqlalchemy_test_case.py index 652e05f4..db89d57b 100644 --- a/pyhive/tests/sqlalchemy_test_case.py +++ b/pyhive/tests/sqlalchemy_test_case.py @@ -3,6 +3,7 @@ from __future__ import unicode_literals import abc +import re import contextlib import functools @@ -14,8 +15,10 @@ from sqlalchemy.schema import Index from sqlalchemy.schema import MetaData from sqlalchemy.schema import Table -from sqlalchemy.sql import expression +from sqlalchemy.sql import expression, text +from sqlalchemy import String +sqlalchemy_version = float(re.search(r"^([\d]+\.[\d]+)\..+", sqlalchemy.__version__).group(1)) def with_engine_connection(fn): """Pass a connection to the given function and handle cleanup. @@ -32,19 +35,33 @@ def wrapped_fn(self, *args, **kwargs): engine.dispose() return wrapped_fn +def reflect_table(engine, connection, table, include_columns, exclude_columns, resolve_fks): + if sqlalchemy_version >= 1.4: + insp = sqlalchemy.inspect(engine) + insp.reflect_table( + table, + include_columns=include_columns, + exclude_columns=exclude_columns, + resolve_fks=resolve_fks, + ) + else: + engine.dialect.reflecttable( + connection, table, include_columns=include_columns, + exclude_columns=exclude_columns, resolve_fks=resolve_fks) + class SqlAlchemyTestCase(with_metaclass(abc.ABCMeta, object)): @with_engine_connection def test_basic_query(self, engine, connection): - rows = connection.execute('SELECT * FROM one_row').fetchall() + rows = connection.execute(text('SELECT * FROM one_row')).fetchall() self.assertEqual(len(rows), 1) self.assertEqual(rows[0].number_of_rows, 1) # number_of_rows is the column name self.assertEqual(len(rows[0]), 1) @with_engine_connection def test_one_row_complex_null(self, engine, connection): - one_row_complex_null = Table('one_row_complex_null', MetaData(bind=engine), autoload=True) - rows = one_row_complex_null.select().execute().fetchall() + one_row_complex_null = Table('one_row_complex_null', MetaData(), autoload_with=engine) + rows = connection.execute(one_row_complex_null.select()).fetchall() self.assertEqual(len(rows), 1) self.assertEqual(list(rows[0]), [None] * len(rows[0])) @@ -53,27 +70,26 @@ def test_reflect_no_such_table(self, engine, connection): """reflecttable should throw an exception on an invalid table""" self.assertRaises( NoSuchTableError, - lambda: Table('this_does_not_exist', MetaData(bind=engine), autoload=True)) + lambda: Table('this_does_not_exist', MetaData(), autoload_with=engine)) self.assertRaises( NoSuchTableError, - lambda: Table('this_does_not_exist', MetaData(bind=engine), - schema='also_does_not_exist', autoload=True)) + lambda: Table('this_does_not_exist', MetaData(schema='also_does_not_exist'), autoload_with=engine)) @with_engine_connection def test_reflect_include_columns(self, engine, connection): """When passed include_columns, reflecttable should filter out other columns""" - one_row_complex = Table('one_row_complex', MetaData(bind=engine)) - engine.dialect.reflecttable( - connection, one_row_complex, include_columns=['int'], + + one_row_complex = Table('one_row_complex', MetaData()) + reflect_table(engine, connection, one_row_complex, include_columns=['int'], exclude_columns=[], resolve_fks=True) + self.assertEqual(len(one_row_complex.c), 1) self.assertIsNotNone(one_row_complex.c.int) self.assertRaises(AttributeError, lambda: one_row_complex.c.tinyint) @with_engine_connection def test_reflect_with_schema(self, engine, connection): - dummy = Table('dummy_table', MetaData(bind=engine), schema='pyhive_test_database', - autoload=True) + dummy = Table('dummy_table', MetaData(schema='pyhive_test_database'), autoload_with=engine) self.assertEqual(len(dummy.c), 1) self.assertIsNotNone(dummy.c.a) @@ -81,22 +97,22 @@ def test_reflect_with_schema(self, engine, connection): @with_engine_connection def test_reflect_partitions(self, engine, connection): """reflecttable should get the partition column as an index""" - many_rows = Table('many_rows', MetaData(bind=engine), autoload=True) + many_rows = Table('many_rows', MetaData(), autoload_with=engine) self.assertEqual(len(many_rows.c), 2) self.assertEqual(repr(many_rows.indexes), repr({Index('partition', many_rows.c.b)})) - many_rows = Table('many_rows', MetaData(bind=engine)) - engine.dialect.reflecttable( - connection, many_rows, include_columns=['a'], + many_rows = Table('many_rows', MetaData()) + reflect_table(engine, connection, many_rows, include_columns=['a'], exclude_columns=[], resolve_fks=True) + self.assertEqual(len(many_rows.c), 1) self.assertFalse(many_rows.c.a.index) self.assertFalse(many_rows.indexes) - many_rows = Table('many_rows', MetaData(bind=engine)) - engine.dialect.reflecttable( - connection, many_rows, include_columns=['b'], + many_rows = Table('many_rows', MetaData()) + reflect_table(engine, connection, many_rows, include_columns=['b'], exclude_columns=[], resolve_fks=True) + self.assertEqual(len(many_rows.c), 1) self.assertEqual(repr(many_rows.indexes), repr({Index('partition', many_rows.c.b)})) @@ -104,11 +120,15 @@ def test_reflect_partitions(self, engine, connection): def test_unicode(self, engine, connection): """Verify that unicode strings make it through SQLAlchemy and the backend""" unicode_str = "中文" - one_row = Table('one_row', MetaData(bind=engine)) - returned_str = sqlalchemy.select( - [expression.bindparam("好", unicode_str)], - from_obj=one_row, - ).scalar() + one_row = Table('one_row', MetaData()) + + if sqlalchemy_version >= 1.4: + returned_str = connection.execute(sqlalchemy.select( + expression.bindparam("好", unicode_str, type_=String())).select_from(one_row)).scalar() + else: + returned_str = connection.execute(sqlalchemy.select([ + expression.bindparam("好", unicode_str, type_=String())]).select_from(one_row)).scalar() + self.assertEqual(returned_str, unicode_str) @with_engine_connection @@ -133,13 +153,21 @@ def test_get_table_names(self, engine, connection): @with_engine_connection def test_has_table(self, engine, connection): - self.assertTrue(Table('one_row', MetaData(bind=engine)).exists()) - self.assertFalse(Table('this_table_does_not_exist', MetaData(bind=engine)).exists()) + if sqlalchemy_version >= 1.4: + insp = sqlalchemy.inspect(engine) + self.assertTrue(insp.has_table("one_row")) + self.assertFalse(insp.has_table("this_table_does_not_exist")) + else: + self.assertTrue(Table('one_row', MetaData(bind=engine)).exists()) + self.assertFalse(Table('this_table_does_not_exist', MetaData(bind=engine)).exists()) @with_engine_connection def test_char_length(self, engine, connection): - one_row_complex = Table('one_row_complex', MetaData(bind=engine), autoload=True) - result = sqlalchemy.select([ - sqlalchemy.func.char_length(one_row_complex.c.string) - ]).execute().scalar() + one_row_complex = Table('one_row_complex', MetaData(), autoload_with=engine) + + if sqlalchemy_version >= 1.4: + result = connection.execute(sqlalchemy.select(sqlalchemy.func.char_length(one_row_complex.c.string))).scalar() + else: + result = connection.execute(sqlalchemy.select([sqlalchemy.func.char_length(one_row_complex.c.string)])).scalar() + self.assertEqual(result, len('a string')) diff --git a/pyhive/tests/test_sqlalchemy_hive.py b/pyhive/tests/test_sqlalchemy_hive.py index 1ff0e817..790bec4c 100644 --- a/pyhive/tests/test_sqlalchemy_hive.py +++ b/pyhive/tests/test_sqlalchemy_hive.py @@ -4,6 +4,7 @@ from pyhive.sqlalchemy_hive import HiveDate from pyhive.sqlalchemy_hive import HiveDecimal from pyhive.sqlalchemy_hive import HiveTimestamp +from sqlalchemy.exc import NoSuchTableError, OperationalError from pyhive.tests.sqlalchemy_test_case import SqlAlchemyTestCase from pyhive.tests.sqlalchemy_test_case import with_engine_connection from sqlalchemy import types @@ -11,11 +12,15 @@ from sqlalchemy.schema import Column from sqlalchemy.schema import MetaData from sqlalchemy.schema import Table +from sqlalchemy.sql import text import contextlib import datetime import decimal import sqlalchemy.types import unittest +import re + +sqlalchemy_version = float(re.search(r"^([\d]+\.[\d]+)\..+", sqlalchemy.__version__).group(1)) _ONE_ROW_COMPLEX_CONTENTS = [ True, @@ -64,7 +69,11 @@ def test_dotted_column_names(self, engine, connection): """When Hive returns a dotted column name, both the non-dotted version should be available as an attribute, and the dotted version should remain available as a key. """ - row = connection.execute('SELECT * FROM one_row').fetchone() + row = connection.execute(text('SELECT * FROM one_row')).fetchone() + + if sqlalchemy_version >= 1.4: + row = row._mapping + assert row.keys() == ['number_of_rows'] assert 'number_of_rows' in row assert row.number_of_rows == 1 @@ -76,20 +85,33 @@ def test_dotted_column_names(self, engine, connection): def test_dotted_column_names_raw(self, engine, connection): """When Hive returns a dotted column name, and raw mode is on, nothing should be modified. """ - row = connection.execution_options(hive_raw_colnames=True) \ - .execute('SELECT * FROM one_row').fetchone() + row = connection.execution_options(hive_raw_colnames=True).execute(text('SELECT * FROM one_row')).fetchone() + + if sqlalchemy_version >= 1.4: + row = row._mapping + assert row.keys() == ['one_row.number_of_rows'] assert 'number_of_rows' not in row assert getattr(row, 'one_row.number_of_rows') == 1 assert row['one_row.number_of_rows'] == 1 + @with_engine_connection + def test_reflect_no_such_table(self, engine, connection): + """reflecttable should throw an exception on an invalid table""" + self.assertRaises( + NoSuchTableError, + lambda: Table('this_does_not_exist', MetaData(), autoload_with=engine)) + self.assertRaises( + OperationalError, + lambda: Table('this_does_not_exist', MetaData(schema="also_does_not_exist"), autoload_with=engine)) + @with_engine_connection def test_reflect_select(self, engine, connection): """reflecttable should be able to fill in a table from the name""" - one_row_complex = Table('one_row_complex', MetaData(bind=engine), autoload=True) + one_row_complex = Table('one_row_complex', MetaData(), autoload_with=engine) self.assertEqual(len(one_row_complex.c), 15) self.assertIsInstance(one_row_complex.c.string, Column) - row = one_row_complex.select().execute().fetchone() + row = connection.execute(one_row_complex.select()).fetchone() self.assertEqual(list(row), _ONE_ROW_COMPLEX_CONTENTS) # TODO some of these types could be filled in better @@ -112,15 +134,15 @@ def test_reflect_select(self, engine, connection): @with_engine_connection def test_type_map(self, engine, connection): """sqlalchemy should use the dbapi_type_map to infer types from raw queries""" - row = connection.execute('SELECT * FROM one_row_complex').fetchone() + row = connection.execute(text('SELECT * FROM one_row_complex')).fetchone() self.assertListEqual(list(row), _ONE_ROW_COMPLEX_CONTENTS) @with_engine_connection def test_reserved_words(self, engine, connection): """Hive uses backticks""" # Use keywords for the table/column name - fake_table = Table('select', MetaData(bind=engine), Column('map', sqlalchemy.types.String)) - query = str(fake_table.select(fake_table.c.map == 'a')) + fake_table = Table('select', MetaData(), Column('map', sqlalchemy.types.String)) + query = str(fake_table.select().where(fake_table.c.map == 'a').compile(engine)) self.assertIn('`select`', query) self.assertIn('`map`', query) self.assertNotIn('"select"', query) @@ -132,12 +154,12 @@ def test_switch_database(self): with contextlib.closing(engine.connect()) as connection: self.assertIn( ('dummy_table',), - connection.execute('SHOW TABLES').fetchall() + connection.execute(text('SHOW TABLES')).fetchall() ) - connection.execute('USE default') + connection.execute(text('USE default')) self.assertIn( ('one_row',), - connection.execute('SHOW TABLES').fetchall() + connection.execute(text('SHOW TABLES')).fetchall() ) finally: engine.dispose() @@ -160,13 +182,13 @@ def test_lots_of_types(self, engine, connection): cols.append(Column('hive_date', HiveDate)) cols.append(Column('hive_decimal', HiveDecimal)) cols.append(Column('hive_timestamp', HiveTimestamp)) - table = Table('test_table', MetaData(bind=engine), *cols, schema='pyhive_test_database') - table.drop(checkfirst=True) - table.create() - connection.execute('SET mapred.job.tracker=local') - connection.execute('USE pyhive_test_database') + table = Table('test_table', MetaData(schema='pyhive_test_database'), *cols,) + table.drop(checkfirst=True, bind=connection) + table.create(bind=connection) + connection.execute(text('SET mapred.job.tracker=local')) + connection.execute(text('USE pyhive_test_database')) big_number = 10 ** 10 - 1 - connection.execute(""" + connection.execute(text(""" INSERT OVERWRITE TABLE test_table SELECT 1, "a", "a", "a", "a", "a", 0.1, @@ -175,41 +197,39 @@ def test_lots_of_types(self, engine, connection): "a", 1, 1, 0.1, 0.1, 0, 0, 0, "a", false, "a", "a", - 0, %d, 123 + 2000 + 0, :big_number, 123 + 2000 FROM default.one_row - """, big_number) - row = connection.execute(table.select()).fetchone() - self.assertEqual(row.hive_date, datetime.date(1970, 1, 1)) + """), {"big_number": big_number}) + row = connection.execute(text("select * from test_table")).fetchone() + self.assertEqual(row.hive_date, datetime.datetime(1970, 1, 1, 0, 0)) self.assertEqual(row.hive_decimal, decimal.Decimal(big_number)) self.assertEqual(row.hive_timestamp, datetime.datetime(1970, 1, 1, 0, 0, 2, 123000)) - table.drop() + table.drop(bind=connection) @with_engine_connection def test_insert_select(self, engine, connection): - one_row = Table('one_row', MetaData(bind=engine), autoload=True) - table = Table('insert_test', MetaData(bind=engine), - Column('a', sqlalchemy.types.Integer), - schema='pyhive_test_database') - table.drop(checkfirst=True) - table.create() - connection.execute('SET mapred.job.tracker=local') + one_row = Table('one_row', MetaData(), autoload_with=engine) + table = Table('insert_test', MetaData(schema='pyhive_test_database'), + Column('a', sqlalchemy.types.Integer)) + table.drop(checkfirst=True, bind=connection) + table.create(bind=connection) + connection.execute(text('SET mapred.job.tracker=local')) # NOTE(jing) I'm stuck on a version of Hive without INSERT ... VALUES connection.execute(table.insert().from_select(['a'], one_row.select())) - - result = table.select().execute().fetchall() + + result = connection.execute(table.select()).fetchall() expected = [(1,)] self.assertEqual(result, expected) @with_engine_connection def test_insert_values(self, engine, connection): - table = Table('insert_test', MetaData(bind=engine), - Column('a', sqlalchemy.types.Integer), - schema='pyhive_test_database') - table.drop(checkfirst=True) - table.create() - connection.execute(table.insert([{'a': 1}, {'a': 2}])) - - result = table.select().execute().fetchall() + table = Table('insert_test', MetaData(schema='pyhive_test_database'), + Column('a', sqlalchemy.types.Integer),) + table.drop(checkfirst=True, bind=connection) + table.create(bind=connection) + connection.execute(table.insert().values([{'a': 1}, {'a': 2}])) + + result = connection.execute(table.select()).fetchall() expected = [(1,), (2,)] self.assertEqual(result, expected) diff --git a/pyhive/tests/test_sqlalchemy_presto.py b/pyhive/tests/test_sqlalchemy_presto.py index a01e4a35..58a5c034 100644 --- a/pyhive/tests/test_sqlalchemy_presto.py +++ b/pyhive/tests/test_sqlalchemy_presto.py @@ -8,7 +8,9 @@ from sqlalchemy.schema import Column from sqlalchemy.schema import MetaData from sqlalchemy.schema import Table +from sqlalchemy.sql import text from sqlalchemy.types import String +from decimal import Decimal import contextlib import unittest @@ -27,11 +29,11 @@ def test_bad_format(self): @with_engine_connection def test_reflect_select(self, engine, connection): """reflecttable should be able to fill in a table from the name""" - one_row_complex = Table('one_row_complex', MetaData(bind=engine), autoload=True) + one_row_complex = Table('one_row_complex', MetaData(), autoload_with=engine) # Presto ignores the union column self.assertEqual(len(one_row_complex.c), 15 - 1) self.assertIsInstance(one_row_complex.c.string, Column) - rows = one_row_complex.select().execute().fetchall() + rows = connection.execute(one_row_complex.select()).fetchall() self.assertEqual(len(rows), 1) self.assertEqual(list(rows[0]), [ True, @@ -48,7 +50,7 @@ def test_reflect_select(self, engine, connection): {"1": 2, "3": 4}, # Presto converts all keys to strings so that they're valid JSON [1, 2], # struct is returned as a list of elements # '{0:1}', - '0.1', + Decimal('0.1'), ]) # TODO some of these types could be filled in better @@ -71,7 +73,7 @@ def test_url_default(self): engine = create_engine('presto://localhost:8080/hive') try: with contextlib.closing(engine.connect()) as connection: - self.assertEqual(connection.execute('SELECT 1 AS foobar FROM one_row').scalar(), 1) + self.assertEqual(connection.execute(text('SELECT 1 AS foobar FROM one_row')).scalar(), 1) finally: engine.dispose() @@ -79,8 +81,8 @@ def test_url_default(self): def test_reserved_words(self, engine, connection): """Presto uses double quotes, not backticks""" # Use keywords for the table/column name - fake_table = Table('select', MetaData(bind=engine), Column('current_timestamp', String)) - query = str(fake_table.select(fake_table.c.current_timestamp == 'a')) + fake_table = Table('select', MetaData(), Column('current_timestamp', String)) + query = str(fake_table.select().where(fake_table.c.current_timestamp == 'a').compile(engine)) self.assertIn('"select"', query) self.assertIn('"current_timestamp"', query) self.assertNotIn('`select`', query) diff --git a/pyhive/tests/test_sqlalchemy_trino.py b/pyhive/tests/test_sqlalchemy_trino.py new file mode 100644 index 00000000..c929f941 --- /dev/null +++ b/pyhive/tests/test_sqlalchemy_trino.py @@ -0,0 +1,93 @@ +from sqlalchemy.engine import create_engine +from pyhive.tests.sqlalchemy_test_case import SqlAlchemyTestCase +from pyhive.tests.sqlalchemy_test_case import with_engine_connection +from sqlalchemy.exc import NoSuchTableError, DatabaseError +from sqlalchemy.schema import MetaData, Table, Column +from sqlalchemy.types import String +from sqlalchemy.sql import text +from sqlalchemy import types +from decimal import Decimal + +import unittest +import contextlib + + +class TestSqlAlchemyTrino(unittest.TestCase, SqlAlchemyTestCase): + def create_engine(self): + return create_engine('trino+pyhive://localhost:18080/hive/default?source={}'.format(self.id())) + + def test_bad_format(self): + self.assertRaises( + ValueError, + lambda: create_engine('trino+pyhive://localhost:18080/hive/default/what'), + ) + + @with_engine_connection + def test_reflect_select(self, engine, connection): + """reflecttable should be able to fill in a table from the name""" + one_row_complex = Table('one_row_complex', MetaData(), autoload_with=engine) + # Presto ignores the union column + self.assertEqual(len(one_row_complex.c), 15 - 1) + self.assertIsInstance(one_row_complex.c.string, Column) + rows = connection.execute(one_row_complex.select()).fetchall() + self.assertEqual(len(rows), 1) + self.assertEqual(list(rows[0]), [ + True, + 127, + 32767, + 2147483647, + 9223372036854775807, + 0.5, + 0.25, + 'a string', + '1970-01-01 00:00:00.000', + b'123', + [1, 2], + {"1": 2, "3": 4}, + [1, 2], + Decimal('0.1'), + ]) + + self.assertIsInstance(one_row_complex.c.boolean.type, types.Boolean) + self.assertIsInstance(one_row_complex.c.tinyint.type, types.Integer) + self.assertIsInstance(one_row_complex.c.smallint.type, types.Integer) + self.assertIsInstance(one_row_complex.c.int.type, types.Integer) + self.assertIsInstance(one_row_complex.c.bigint.type, types.BigInteger) + self.assertIsInstance(one_row_complex.c.float.type, types.Float) + self.assertIsInstance(one_row_complex.c.double.type, types.Float) + self.assertIsInstance(one_row_complex.c.string.type, String) + self.assertIsInstance(one_row_complex.c.timestamp.type, types.NullType) + self.assertIsInstance(one_row_complex.c.binary.type, types.VARBINARY) + self.assertIsInstance(one_row_complex.c.array.type, types.NullType) + self.assertIsInstance(one_row_complex.c.map.type, types.NullType) + self.assertIsInstance(one_row_complex.c.struct.type, types.NullType) + self.assertIsInstance(one_row_complex.c.decimal.type, types.NullType) + + @with_engine_connection + def test_reflect_no_such_table(self, engine, connection): + """reflecttable should throw an exception on an invalid table""" + self.assertRaises( + NoSuchTableError, + lambda: Table('this_does_not_exist', MetaData(), autoload_with=engine)) + self.assertRaises( + DatabaseError, + lambda: Table('this_does_not_exist', MetaData(schema="also_does_not_exist"), autoload_with=engine)) + + def test_url_default(self): + engine = create_engine('trino+pyhive://localhost:18080/hive') + try: + with contextlib.closing(engine.connect()) as connection: + self.assertEqual(connection.execute(text('SELECT 1 AS foobar FROM one_row')).scalar(), 1) + finally: + engine.dispose() + + @with_engine_connection + def test_reserved_words(self, engine, connection): + """Trino uses double quotes, not backticks""" + # Use keywords for the table/column name + fake_table = Table('select', MetaData(), Column('current_timestamp', String)) + query = str(fake_table.select().where(fake_table.c.current_timestamp == 'a').compile(engine)) + self.assertIn('"select"', query) + self.assertIn('"current_timestamp"', query) + self.assertNotIn('`select`', query) + self.assertNotIn('`current_timestamp`', query)