diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 75cc6c49..57e10c3a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -58,6 +58,8 @@ jobs: - { python: "3.11", trino: "351", sqlalchemy: "~=1.4.0" } # first Trino version # Test with sqlalchemy 1.3 - { python: "3.11", trino: "latest", sqlalchemy: "~=1.3.0" } + # Test with sqlalchemy 2.0 + - { python: "3.11", trino: "latest", sqlalchemy: "~=2.0.0rc1" } env: TRINO_VERSION: "${{ matrix.trino }}" steps: diff --git a/setup.py b/setup.py index c5aa8a1f..9e6003a5 100755 --- a/setup.py +++ b/setup.py @@ -26,7 +26,7 @@ version = str(ast.literal_eval(trino_version.group(1))) kerberos_require = ["requests_kerberos"] -sqlalchemy_require = ["sqlalchemy~=1.3"] +sqlalchemy_require = ["sqlalchemy >= 1.3"] external_authentication_token_cache_require = ["keyring"] # We don't add localstorage_require to all_require as users must explicitly opt in to use keyring. diff --git a/tests/integration/test_sqlalchemy_integration.py b/tests/integration/test_sqlalchemy_integration.py index 9ac2bc02..81f8505e 100644 --- a/tests/integration/test_sqlalchemy_integration.py +++ b/tests/integration/test_sqlalchemy_integration.py @@ -43,10 +43,10 @@ def test_select_query(trino_connection): rows = result.fetchall() assert len(rows) == 25 for row in rows: - assert isinstance(row['nationkey'], int) - assert isinstance(row['name'], str) - assert isinstance(row['regionkey'], int) - assert isinstance(row['comment'], str) + assert isinstance(row.nationkey, int) + assert isinstance(row.name, str) + assert isinstance(row.regionkey, int) + assert isinstance(row.comment, str) def assert_column(table, column_name, column_type): @@ -70,8 +70,8 @@ def test_select_specific_columns(trino_connection): rows = result.fetchall() assert len(rows) > 0 for row in rows: - assert isinstance(row['node_id'], str) - assert isinstance(row['state'], str) + assert isinstance(row.node_id, str) + assert isinstance(row.state, str) @pytest.mark.skipif( @@ -82,7 +82,8 @@ def test_select_specific_columns(trino_connection): def test_define_and_create_table(trino_connection): engine, conn = trino_connection if not engine.dialect.has_schema(conn, "test"): - engine.execute(sqla.schema.CreateSchema("test")) + with engine.begin() as connection: + connection.execute(sqla.schema.CreateSchema("test")) metadata = sqla.MetaData() try: sqla.Table('users', @@ -110,7 +111,8 @@ def test_insert(trino_connection): engine, conn = trino_connection if not engine.dialect.has_schema(conn, "test"): - engine.execute(sqla.schema.CreateSchema("test")) + with engine.begin() as connection: + connection.execute(sqla.schema.CreateSchema("test")) metadata = sqla.MetaData() try: users = sqla.Table('users', @@ -139,7 +141,8 @@ def test_insert(trino_connection): def test_insert_multiple_statements(trino_connection): engine, conn = trino_connection if not engine.dialect.has_schema(conn, "test"): - engine.execute(sqla.schema.CreateSchema("test")) + with engine.begin() as connection: + connection.execute(sqla.schema.CreateSchema("test")) metadata = sqla.MetaData() users = sqla.Table('users', metadata, @@ -180,10 +183,10 @@ def test_operators(trino_connection): rows = result.fetchall() assert len(rows) == 1 for row in rows: - assert isinstance(row['nationkey'], int) - assert isinstance(row['name'], str) - assert isinstance(row['regionkey'], int) - assert isinstance(row['comment'], str) + assert isinstance(row.nationkey, int) + assert isinstance(row.name, str) + assert isinstance(row.regionkey, int) + assert isinstance(row.comment, str) @pytest.mark.skipif( @@ -216,14 +219,14 @@ def test_textual_sql(trino_connection): rows = result.fetchall() assert len(rows) == 3 for row in rows: - assert isinstance(row['custkey'], int) - assert isinstance(row['name'], str) - assert isinstance(row['address'], str) - assert isinstance(row['nationkey'], int) - assert isinstance(row['phone'], str) - assert isinstance(row['acctbal'], float) - assert isinstance(row['mktsegment'], str) - assert isinstance(row['comment'], str) + assert isinstance(row.custkey, int) + assert isinstance(row.name, str) + assert isinstance(row.address, str) + assert isinstance(row.nationkey, int) + assert isinstance(row.phone, str) + assert isinstance(row.acctbal, float) + assert isinstance(row.mktsegment, str) + assert isinstance(row.comment, str) @pytest.mark.skipif( @@ -323,7 +326,8 @@ def test_json_column(trino_connection, json_object): engine, conn = trino_connection if not engine.dialect.has_schema(conn, "test"): - engine.execute(sqla.schema.CreateSchema("test")) + with engine.begin() as connection: + connection.execute(sqla.schema.CreateSchema("test")) metadata = sqla.MetaData() try: @@ -351,7 +355,8 @@ def test_get_table_comment(trino_connection): engine, conn = trino_connection if not engine.dialect.has_schema(conn, "test"): - engine.execute(sqla.schema.CreateSchema("test")) + with engine.begin() as connection: + connection.execute(sqla.schema.CreateSchema("test")) metadata = sqla.MetaData() try: @@ -378,7 +383,8 @@ def test_get_table_names(trino_connection, schema): metadata = sqla.MetaData(schema=schema_name) if not engine.dialect.has_schema(conn, schema_name): - engine.execute(sqla.schema.CreateSchema(schema_name)) + with engine.begin() as connection: + connection.execute(sqla.schema.CreateSchema(schema_name)) try: sqla.Table( @@ -388,10 +394,10 @@ def test_get_table_names(trino_connection, schema): ) metadata.create_all(engine) view_name = schema_name + ".test_view" - conn.execute(f"CREATE VIEW {view_name} AS SELECT * FROM test_get_table_names") + conn.execute(sqla.text(f"CREATE VIEW {view_name} AS SELECT * FROM test_get_table_names")) assert sqla.inspect(engine).get_table_names(schema_name) == ['test_get_table_names'] finally: - conn.execute(f"DROP VIEW IF EXISTS {view_name}") + conn.execute(sqla.text(f"DROP VIEW IF EXISTS {view_name}")) metadata.drop_all(engine) @@ -411,7 +417,8 @@ def test_get_view_names(trino_connection, schema): metadata = sqla.MetaData(schema=schema_name) if not engine.dialect.has_schema(conn, schema_name): - engine.execute(sqla.schema.CreateSchema(schema_name)) + with engine.begin() as connection: + connection.execute(sqla.schema.CreateSchema(schema_name)) try: sqla.Table( @@ -421,10 +428,10 @@ def test_get_view_names(trino_connection, schema): ) metadata.create_all(engine) view_name = schema_name + ".test_get_view_names" - conn.execute(f"CREATE VIEW {view_name} AS SELECT * FROM test_table") + conn.execute(sqla.text(f"CREATE VIEW {view_name} AS SELECT * FROM test_table")) assert sqla.inspect(engine).get_view_names(schema_name) == ['test_get_view_names'] finally: - conn.execute(f"DROP VIEW IF EXISTS {view_name}") + conn.execute(sqla.text(f"DROP VIEW IF EXISTS {view_name}")) metadata.drop_all(engine) diff --git a/trino/sqlalchemy/dialect.py b/trino/sqlalchemy/dialect.py index 2aaac3de..1fa6ed05 100644 --- a/trino/sqlalchemy/dialect.py +++ b/trino/sqlalchemy/dialect.py @@ -156,7 +156,7 @@ def _get_columns(self, connection: Connection, table_name: str, schema: str = No ORDER BY "ordinal_position" ASC """ ).strip() - res = connection.execute(sql.text(query), schema=schema, table=table_name) + res = connection.execute(sql.text(query), {"schema": schema, "table": table_name}) columns = [] for record in res: column = dict( @@ -204,7 +204,7 @@ def get_table_names(self, connection: Connection, schema: str = None, **kw) -> L AND "table_type" = 'BASE TABLE' """ ).strip() - res = connection.execute(sql.text(query), schema=schema) + res = connection.execute(sql.text(query), {"schema": schema}) return [row.table_name for row in res] def get_temp_table_names(self, connection: Connection, schema: str = None, **kw) -> List[str]: @@ -225,7 +225,7 @@ def get_view_names(self, connection: Connection, schema: str = None, **kw) -> Li AND "table_type" = 'VIEW' """ ).strip() - res = connection.execute(sql.text(query), schema=schema) + res = connection.execute(sql.text(query), {"schema": schema}) return [row.table_name for row in res] def get_temp_view_names(self, connection: Connection, schema: str = None, **kw) -> List[str]: @@ -244,7 +244,7 @@ def get_view_definition(self, connection: Connection, view_name: str, schema: st AND "table_name" = :view """ ).strip() - res = connection.execute(sql.text(query), schema=schema, view=view_name) + res = connection.execute(sql.text(query), {"schema": schema, "view": view_name}) return res.scalar() def get_indexes(self, connection: Connection, table_name: str, schema: str = None, **kw) -> List[Dict[str, Any]]: @@ -296,7 +296,7 @@ def get_table_comment(self, connection: Connection, table_name: str, schema: str try: res = connection.execute( sql.text(query), - catalog_name=catalog_name, schema_name=schema_name, table_name=table_name + {"catalog_name": catalog_name, "schema_name": schema_name, "table_name": table_name} ) return dict(text=res.scalar()) except error.TrinoQueryError as e: @@ -314,7 +314,7 @@ def has_schema(self, connection: Connection, schema: str) -> bool: WHERE "schema_name" = :schema """ ).strip() - res = connection.execute(sql.text(query), schema=schema) + res = connection.execute(sql.text(query), {"schema": schema}) return res.first() is not None def has_table(self, connection: Connection, table_name: str, schema: str = None, **kw) -> bool: @@ -329,7 +329,7 @@ def has_table(self, connection: Connection, table_name: str, schema: str = None, AND "table_name" = :table """ ).strip() - res = connection.execute(sql.text(query), schema=schema, table=table_name) + res = connection.execute(sql.text(query), {"schema": schema, "table": table_name}) return res.first() is not None def has_sequence(self, connection: Connection, sequence_name: str, schema: str = None, **kw) -> bool: @@ -363,11 +363,6 @@ def do_execute( self, cursor: Cursor, statement: str, parameters: Tuple[Any, ...], context: DefaultExecutionContext = None ): cursor.execute(statement, parameters) - if context and context.should_autocommit: - # SQL statement only submitted to Trino server when cursor.fetch*() is called. - # For DDL (CREATE/ALTER/DROP) and DML (INSERT/UPDATE/DELETE) statement, call cursor.description - # to force submit statement immediately. - cursor.description # noqa def do_rollback(self, dbapi_connection: trino_dbapi.Connection): if dbapi_connection.transaction is not None: