From 644ef91cb50488465a31aa9bbb035e87ebdb5e12 Mon Sep 17 00:00:00 2001
From: Ben Cassell
Date: Mon, 12 Feb 2024 16:32:14 -0800
Subject: [PATCH] Get ready for test automation

- formatting
- Make sql client change backwards compatible, fix broken tests
- fix the integration workflow and linting
- fix broken tests
- fix path
- lint + install missing sqlalchemy library
- skip cache for one run
- fix integration so that it doesn't uninstall sqlalchemy

Signed-off-by: Ben Cassell
---
 .github/workflows/code-quality-checks.yml   |   4 +-
 .github/workflows/integration.yml           |  59 ++
 conftest.py                                 |  44 ++
 pyproject.toml                              |   1 +
 src/databricks/sql/client.py                |   9 +-
 src/databricks/sqlalchemy/pytest.ini        |   3 -
 .../sqlalchemy/test_local/conftest.py       |  44 ++
 .../sqlalchemy/test_local/e2e/test_basic.py |  53 +-
 .../sqlalchemy/test_local/test_parsing.py   |   8 +-
 test.env.example                            |  12 +-
 tests/e2e/common/core_tests.py              | 167 +++--
 tests/e2e/common/decimal_tests.py           |  53 +-
 tests/e2e/common/large_queries_mixin.py     |  39 +-
 tests/e2e/common/retry_test_mixins.py       | 148 ++---
 tests/e2e/common/staging_ingestion_tests.py | 138 ++--
 tests/e2e/common/timestamp_tests.py         |  65 +-
 tests/e2e/common/uc_volume_tests.py         | 130 ++--
 tests/e2e/test_complex_types.py             |   8 +-
 tests/e2e/test_driver.py                    | 606 +++++++++---------
 tests/e2e/test_parameterized_queries.py     |  41 +-
 20 files changed, 916 insertions(+), 716 deletions(-)
 create mode 100644 .github/workflows/integration.yml
 create mode 100644 conftest.py
 delete mode 100644 src/databricks/sqlalchemy/pytest.ini
 create mode 100644 src/databricks/sqlalchemy/test_local/conftest.py

diff --git a/.github/workflows/code-quality-checks.yml b/.github/workflows/code-quality-checks.yml
index bfb8ca942..03c3991d0 100644
--- a/.github/workflows/code-quality-checks.yml
+++ b/.github/workflows/code-quality-checks.yml
@@ -1,5 +1,5 @@
 name: Code Quality Checks
-on: 
+on:
   push:
     branches:
       - main
@@ -157,7 +157,7 @@ jobs:
       - name: Install library
         run: poetry install --no-interaction
       #----------------------------------------------
-      # black the code
+      # mypy the code
       #----------------------------------------------
       - name: Mypy
         run: poetry run mypy --install-types --non-interactive src
diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml
new file mode 100644
index 000000000..f28c22a85
--- /dev/null
+++ b/.github/workflows/integration.yml
@@ -0,0 +1,59 @@
+name: Integration Tests
+on:
+  push:
+    paths-ignore:
+      - "**.MD"
+      - "**.md"
+
+jobs:
+  run-e2e-tests:
+    runs-on: ubuntu-latest
+    environment: azure-prod
+    env:
+      DATABRICKS_SERVER_HOSTNAME: ${{ secrets.DATABRICKS_HOST }}
+      DATABRICKS_HTTP_PATH: ${{ secrets.TEST_PECO_WAREHOUSE_HTTP_PATH }}
+      DATABRICKS_TOKEN: ${{ secrets.DATABRICKS_TOKEN }}
+      DATABRICKS_CATALOG: peco
+      DATABRICKS_USER: ${{ secrets.TEST_PECO_SP_ID }}
+    steps:
+      #----------------------------------------------
+      # check-out repo and set-up python
+      #----------------------------------------------
+      - name: Check out repository
+        uses: actions/checkout@v3
+      - name: Set up python
+        id: setup-python
+        uses: actions/setup-python@v4
+        with:
+          python-version: "3.10"
+      #----------------------------------------------
+      # ----- install & configure poetry -----
+      #----------------------------------------------
+      - name: Install Poetry
+        uses: snok/install-poetry@v1
+        with:
+          virtualenvs-create: true
+          virtualenvs-in-project: true
+          installer-parallel: true
+
+      #----------------------------------------------
+      # load cached venv if cache exists
+      #----------------------------------------------
+      - name: Load 
cached venv + id: cached-poetry-dependencies + uses: actions/cache@v2 + with: + path: .venv + key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ github.event.repository.name }}-${{ hashFiles('**/poetry.lock') }} + #---------------------------------------------- + # install dependencies if cache does not exist + #---------------------------------------------- + - name: Install dependencies + run: poetry install --no-interaction --all-extras + #---------------------------------------------- + # run test suite + #---------------------------------------------- + - name: Run e2e tests + run: poetry run python -m pytest tests/e2e + - name: Run SQL Alchemy tests + run: poetry run python -m pytest src/databricks/sqlalchemy/test_local diff --git a/conftest.py b/conftest.py new file mode 100644 index 000000000..c8b350bee --- /dev/null +++ b/conftest.py @@ -0,0 +1,44 @@ +import os +import pytest + + +@pytest.fixture(scope="session") +def host(): + return os.getenv("DATABRICKS_SERVER_HOSTNAME") + + +@pytest.fixture(scope="session") +def http_path(): + return os.getenv("DATABRICKS_HTTP_PATH") + + +@pytest.fixture(scope="session") +def access_token(): + return os.getenv("DATABRICKS_TOKEN") + + +@pytest.fixture(scope="session") +def ingestion_user(): + return os.getenv("DATABRICKS_USER") + + +@pytest.fixture(scope="session") +def catalog(): + return os.getenv("DATABRICKS_CATALOG") + + +@pytest.fixture(scope="session") +def schema(): + return os.getenv("DATABRICKS_SCHEMA", "default") + + +@pytest.fixture(scope="session", autouse=True) +def connection_details(host, http_path, access_token, ingestion_user, catalog, schema): + return { + "host": host, + "http_path": http_path, + "access_token": access_token, + "ingestion_user": ingestion_user, + "catalog": catalog, + "schema": schema, + } diff --git a/pyproject.toml b/pyproject.toml index 2771b3ae0..4a5b417eb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,6 +58,7 @@ exclude = ['ttypes\.py$', 'TCLIService\.py$'] exclude = '/(\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|\.svn|_build|buck-out|build|dist|thrift_api)/' [tool.pytest.ini_options] +markers = {"reviewed" = "Test case has been reviewed by Databricks"} minversion = "6.0" log_cli = "false" log_cli_level = "INFO" diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 45f116f0a..313deb24e 100644 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -605,12 +605,15 @@ def _handle_staging_operation( "Local file operations are restricted to paths within the configured staging_allowed_local_path" ) - # TODO: Experiment with DBR sending real headers. 
-        # The specification says headers will be in JSON format but the current null value is actually an empty list []
+        # Headers may arrive as a JSON string or as an already-parsed object; normalize to a dict
+        headers = (
+            json.loads(row.headers) if isinstance(row.headers, str) else row.headers
+        )
+
         handler_args = {
             "presigned_url": row.presignedUrl,
             "local_file": abs_localFile,
-            "headers": json.loads(row.headers or "{}"),
+            "headers": dict(headers) if headers else {},
         }
 
         logger.debug(
diff --git a/src/databricks/sqlalchemy/pytest.ini b/src/databricks/sqlalchemy/pytest.ini
deleted file mode 100644
index affffd2f8..000000000
--- a/src/databricks/sqlalchemy/pytest.ini
+++ /dev/null
@@ -1,3 +0,0 @@
-[pytest]
-markers =
-    reviewed: Test case has been reviewed by databricks
\ No newline at end of file
diff --git a/src/databricks/sqlalchemy/test_local/conftest.py b/src/databricks/sqlalchemy/test_local/conftest.py
new file mode 100644
index 000000000..c8b350bee
--- /dev/null
+++ b/src/databricks/sqlalchemy/test_local/conftest.py
@@ -0,0 +1,44 @@
+import os
+import pytest
+
+
+@pytest.fixture(scope="session")
+def host():
+    return os.getenv("DATABRICKS_SERVER_HOSTNAME")
+
+
+@pytest.fixture(scope="session")
+def http_path():
+    return os.getenv("DATABRICKS_HTTP_PATH")
+
+
+@pytest.fixture(scope="session")
+def access_token():
+    return os.getenv("DATABRICKS_TOKEN")
+
+
+@pytest.fixture(scope="session")
+def ingestion_user():
+    return os.getenv("DATABRICKS_USER")
+
+
+@pytest.fixture(scope="session")
+def catalog():
+    return os.getenv("DATABRICKS_CATALOG")
+
+
+@pytest.fixture(scope="session")
+def schema():
+    return os.getenv("DATABRICKS_SCHEMA", "default")
+
+
+@pytest.fixture(scope="session", autouse=True)
+def connection_details(host, http_path, access_token, ingestion_user, catalog, schema):
+    return {
+        "host": host,
+        "http_path": http_path,
+        "access_token": access_token,
+        "ingestion_user": ingestion_user,
+        "catalog": catalog,
+        "schema": schema,
+    }
diff --git a/src/databricks/sqlalchemy/test_local/e2e/test_basic.py b/src/databricks/sqlalchemy/test_local/e2e/test_basic.py
index ec54c2821..ce0b5d894 100644
--- a/src/databricks/sqlalchemy/test_local/e2e/test_basic.py
+++ b/src/databricks/sqlalchemy/test_local/e2e/test_basic.py
@@ -1,6 +1,5 @@
 import datetime
 import decimal
-import os
 from typing import Tuple, Union, List
 from unittest import skipIf
@@ -19,7 +18,7 @@
 from sqlalchemy.engine.reflection import Inspector
 from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column
 from sqlalchemy.schema import DropColumnComment, SetColumnComment
-from sqlalchemy.types import BOOLEAN, DECIMAL, Date, DateTime, Integer, String
+from sqlalchemy.types import BOOLEAN, DECIMAL, Date, Integer, String
 
 try:
     from sqlalchemy.orm import declarative_base
@@ -49,12 +48,12 @@ def version_agnostic_select(object_to_select, *args, **kwargs):
     return select(object_to_select, *args, **kwargs)
 
 
-def version_agnostic_connect_arguments(catalog=None, schema=None) -> Tuple[str, dict]:
-    HOST = os.environ.get("host")
-    HTTP_PATH = os.environ.get("http_path")
-    ACCESS_TOKEN = os.environ.get("access_token")
-    CATALOG = catalog or os.environ.get("catalog")
-    SCHEMA = schema or os.environ.get("schema")
+def version_agnostic_connect_arguments(connection_details) -> Tuple[str, dict]:
+    HOST = connection_details["host"]
+    HTTP_PATH = connection_details["http_path"]
+    ACCESS_TOKEN = connection_details["access_token"]
+    CATALOG = connection_details["catalog"]
+    SCHEMA = connection_details["schema"]
 
     ua_connect_args = {"_user_agent_entry": USER_AGENT_TOKEN}
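The client.py hunk above is the backwards-compatibility piece of this change: DBR may send staging headers either as a JSON string (the removed comment notes the old null value was an empty list, []) or as an already-parsed structure. A minimal standalone sketch of that normalization; normalize_headers is a hypothetical name, since the patch inlines this logic in _handle_staging_operation:

import json
from typing import Union

def normalize_headers(raw: Union[str, list, dict, None]) -> dict:
    # Accept a JSON string, a list of [key, value] pairs, a dict, or None.
    parsed = json.loads(raw) if isinstance(raw, str) else raw
    return dict(parsed) if parsed else {}

assert normalize_headers(None) == {}
assert normalize_headers("[]") == {}
assert normalize_headers('{"Content-Type": "text/csv"}') == {"Content-Type": "text/csv"}
assert normalize_headers([["Content-Type", "text/csv"]]) == {"Content-Type": "text/csv"}

Guarding with `if headers` (rather than `dict(headers) or {}`) matters because row.headers can be None, and dict(None) raises a TypeError.

@@ -77,8 +76,8 @@ 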
def version_agnostic_connect_arguments(catalog=None, schema=None) -> Tuple[str, @pytest.fixture -def db_engine() -> Engine: - conn_string, connect_args = version_agnostic_connect_arguments() +def db_engine(connection_details) -> Engine: + conn_string, connect_args = version_agnostic_connect_arguments(connection_details) return create_engine(conn_string, connect_args=connect_args) @@ -92,10 +91,11 @@ def run_query(db_engine: Engine, query: Union[str, Text]): @pytest.fixture -def samples_engine() -> Engine: - conn_string, connect_args = version_agnostic_connect_arguments( - catalog="samples", schema="nyctaxi" - ) +def samples_engine(connection_details) -> Engine: + details = connection_details.copy() + details["catalog"] = "samples" + details["schema"] = "nyctaxi" + conn_string, connect_args = version_agnostic_connect_arguments(details) return create_engine(conn_string, connect_args=connect_args) @@ -141,7 +141,7 @@ def test_connect_args(db_engine): def test_pandas_upload(db_engine, metadata_obj): import pandas as pd - SCHEMA = os.environ.get("schema") + SCHEMA = "default" try: df = pd.read_excel( "src/databricks/sqlalchemy/test_local/e2e/demo_data/MOCK_DATA.xlsx" @@ -409,7 +409,9 @@ def test_get_table_names_smoke_test(samples_engine: Engine): _names is not None, "get_table_names did not succeed" -def test_has_table_across_schemas(db_engine: Engine, samples_engine: Engine): +def test_has_table_across_schemas( + db_engine: Engine, samples_engine: Engine, catalog: str, schema: str +): """For this test to pass these conditions must be met: - Table samples.nyctaxi.trips must exist - Table samples.tpch.customer must exist @@ -426,9 +428,6 @@ def test_has_table_across_schemas(db_engine: Engine, samples_engine: Engine): ) # 3) Check for a table within a different catalog - other_catalog = os.environ.get("catalog") - other_schema = os.environ.get("schema") - # Create a table in a different catalog with db_engine.connect() as conn: conn.execute(text("CREATE TABLE test_has_table (numbers_are_cool INT);")) @@ -442,8 +441,8 @@ def test_has_table_across_schemas(db_engine: Engine, samples_engine: Engine): assert samples_engine.dialect.has_table( connection=conn, table_name="test_has_table", - schema=other_schema, - catalog=other_catalog, + schema=schema, + catalog=catalog, ) finally: conn.execute(text("DROP TABLE test_has_table;")) @@ -503,12 +502,12 @@ def test_get_columns(db_engine, sample_table: str): class TestCommentReflection: @pytest.fixture(scope="class") - def engine(self): - HOST = os.environ.get("host") - HTTP_PATH = os.environ.get("http_path") - ACCESS_TOKEN = os.environ.get("access_token") - CATALOG = os.environ.get("catalog") - SCHEMA = os.environ.get("schema") + def engine(self, connection_details: dict): + HOST = connection_details["host"] + HTTP_PATH = connection_details["http_path"] + ACCESS_TOKEN = connection_details["access_token"] + CATALOG = connection_details["catalog"] + SCHEMA = connection_details["schema"] connection_string = f"databricks://token:{ACCESS_TOKEN}@{HOST}?http_path={HTTP_PATH}&catalog={CATALOG}&schema={SCHEMA}" connect_args = {"_user_agent_entry": USER_AGENT_TOKEN} diff --git a/src/databricks/sqlalchemy/test_local/test_parsing.py b/src/databricks/sqlalchemy/test_local/test_parsing.py index 70e6337a8..c8ab443d0 100644 --- a/src/databricks/sqlalchemy/test_local/test_parsing.py +++ b/src/databricks/sqlalchemy/test_local/test_parsing.py @@ -64,16 +64,16 @@ def test_extract_3l_namespace_from_bad_constraint_string(): 
extract_three_level_identifier_from_constraint_string(input) -@pytest.mark.parametrize("schema", [None, "some_schema"]) -def test_build_fk_dict(schema): +@pytest.mark.parametrize("tschema", [None, "some_schema"]) +def test_build_fk_dict(tschema): fk_constraint_string = "FOREIGN KEY (`parent_user_id`) REFERENCES `main`.`some_schema`.`users` (`user_id`)" - result = build_fk_dict("some_fk_name", fk_constraint_string, schema_name=schema) + result = build_fk_dict("some_fk_name", fk_constraint_string, schema_name=tschema) assert result == { "name": "some_fk_name", "constrained_columns": ["parent_user_id"], - "referred_schema": schema, + "referred_schema": tschema, "referred_table": "users", "referred_columns": ["user_id"], } diff --git a/test.env.example b/test.env.example index 94aed419a..3109f6017 100644 --- a/test.env.example +++ b/test.env.example @@ -1,11 +1,11 @@ # Authentication details for running e2e tests -host="" -http_path="" -access_token="" +DATABRICKS_SERVER_HOSTNAME= +DATABRICKS_HTTP_PATH= +DATABRICKS_TOKEN= # Only required to run the PySQLStagingIngestionTestSuite -staging_ingestion_user="" +DATABRICKS_USER= # Only required to run SQLAlchemy tests -catalog="" -schema="" \ No newline at end of file +DATABRICKS_CATALOG= +DATABRICKS_SCHEMA= \ No newline at end of file diff --git a/tests/e2e/common/core_tests.py b/tests/e2e/common/core_tests.py index cd325e8d0..e89289efc 100644 --- a/tests/e2e/common/core_tests.py +++ b/tests/e2e/common/core_tests.py @@ -3,14 +3,17 @@ from collections import namedtuple TypeFailure = namedtuple( - "TypeFailure", "query,columnType,resultType,resultValue," - "actualValue,actualType,description,conf") + "TypeFailure", + "query,columnType,resultType,resultValue," "actualValue,actualType,description,conf", +) ResultFailure = namedtuple( - "ResultFailure", "query,columnType,resultType,resultValue," - "actualValue,actualType,description,conf") + "ResultFailure", + "query,columnType,resultType,resultValue," "actualValue,actualType,description,conf", +) ExecFailure = namedtuple( - "ExecFailure", "query,columnType,resultType,resultValue," - "actualValue,actualType,description,conf,error") + "ExecFailure", + "query,columnType,resultType,resultValue," "actualValue,actualType,description,conf,error", +) class SmokeTestMixin: @@ -18,8 +21,8 @@ def test_smoke_test(self): with self.cursor() as cursor: cursor.execute("select 0") rows = cursor.fetchall() - self.assertEqual(len(rows), 1) - self.assertEqual(rows[0][0], 0) + assert len(rows) == 1 + assert rows[0][0] == 0 class CoreTestMixin: @@ -32,69 +35,109 @@ class CoreTestMixin: # A list of (subquery, column_type, python_type, expected_result) # To be executed as "SELECT {} FROM RANGE(...)" and "SELECT {}" range_queries = [ - ("TRUE", 'boolean', bool, True), - ("cast(1 AS TINYINT)", 'byte', int, 1), - ("cast(1000 AS SMALLINT)", 'short', int, 1000), - ("cast(100000 AS INTEGER)", 'integer', int, 100000), - ("cast(10000000000000 AS BIGINT)", 'long', int, 10000000000000), - ("cast(100.001 AS DECIMAL(6, 3))", 'decimal', decimal.Decimal, 100.001), - ("date '2020-02-20'", 'date', datetime.date, datetime.date(2020, 2, 20)), - ("unhex('f000')", 'binary', bytes, b'\xf0\x00'), # pyodbc internal mismatch - ("'foo'", 'string', str, 'foo'), + ("TRUE", "boolean", bool, True), + ("cast(1 AS TINYINT)", "byte", int, 1), + ("cast(1000 AS SMALLINT)", "short", int, 1000), + ("cast(100000 AS INTEGER)", "integer", int, 100000), + ("cast(10000000000000 AS BIGINT)", "long", int, 10000000000000), + ("cast(100.001 AS DECIMAL(6, 3))", 
"decimal", decimal.Decimal, 100.001), + ("date '2020-02-20'", "date", datetime.date, datetime.date(2020, 2, 20)), + ("unhex('f000')", "binary", bytes, b"\xf0\x00"), # pyodbc internal mismatch + ("'foo'", "string", str, "foo"), # SPARK-32130: 6.x: "4 weeks 2 days" vs 7.x: "30 days" # ("interval 30 days", str, str, "interval 4 weeks 2 days"), # ("interval 3 days", str, str, "interval 3 days"), - ("CAST(NULL AS DOUBLE)", 'double', type(None), None), + ("CAST(NULL AS DOUBLE)", "double", type(None), None), ] # Full queries, only the first column of the first row is checked - queries = [("NULL UNION (SELECT 1) order by 1", 'integer', type(None), None)] + queries = [("NULL UNION (SELECT 1) order by 1", "integer", type(None), None)] def run_tests_on_queries(self, default_conf): failures = [] - for (query, columnType, rowValueType, answer) in self.range_queries: + for query, columnType, rowValueType, answer in self.range_queries: with self.cursor(default_conf) as cursor: failures.extend( - self.run_query(cursor, query, columnType, rowValueType, answer, default_conf)) + self.run_query(cursor, query, columnType, rowValueType, answer, default_conf) + ) failures.extend( - self.run_range_query(cursor, query, columnType, rowValueType, answer, - default_conf)) + self.run_range_query( + cursor, query, columnType, rowValueType, answer, default_conf + ) + ) - for (query, columnType, rowValueType, answer) in self.queries: + for query, columnType, rowValueType, answer in self.queries: with self.cursor(default_conf) as cursor: failures.extend( - self.run_query(cursor, query, columnType, rowValueType, answer, default_conf)) + self.run_query(cursor, query, columnType, rowValueType, answer, default_conf) + ) if failures: - self.fail("Failed testing result set with Arrow. " - "Failed queries: {}".format("\n\n".join([str(f) for f in failures]))) + self.fail( + "Failed testing result set with Arrow. 
" + "Failed queries: {}".format("\n\n".join([str(f) for f in failures])) + ) def run_query(self, cursor, query, columnType, rowValueType, answer, conf): full_query = "SELECT {}".format(query) expected_column_types = self.expected_column_types(columnType) try: cursor.execute(full_query) - (result, ) = cursor.fetchone() + (result,) = cursor.fetchone() if not all(cursor.description[0][1] == type for type in expected_column_types): return [ - TypeFailure(full_query, expected_column_types, rowValueType, answer, result, - type(result), cursor.description, conf) + TypeFailure( + full_query, + expected_column_types, + rowValueType, + answer, + result, + type(result), + cursor.description, + conf, + ) ] if self.validate_row_value_type and type(result) is not rowValueType: return [ - TypeFailure(full_query, expected_column_types, rowValueType, answer, result, - type(result), cursor.description, conf) + TypeFailure( + full_query, + expected_column_types, + rowValueType, + answer, + result, + type(result), + cursor.description, + conf, + ) ] if self.validate_result and str(answer) != str(result): return [ - ResultFailure(full_query, query, expected_column_types, rowValueType, answer, - result, type(result), cursor.description, conf) + ResultFailure( + full_query, + query, + expected_column_types, + rowValueType, + answer, + result, + type(result), + cursor.description, + conf, + ) ] return [] except Exception as e: return [ - ExecFailure(full_query, columnType, rowValueType, None, None, None, - cursor.description, conf, e) + ExecFailure( + full_query, + columnType, + rowValueType, + None, + None, + None, + cursor.description, + conf, + e, + ) ] def run_range_query(self, cursor, query, columnType, rowValueType, expected, conf): @@ -109,23 +152,55 @@ def run_range_query(self, cursor, query, columnType, rowValueType, expected, con for index, (result, id) in enumerate(rows): if not all(cursor.description[0][1] == type for type in expected_column_types): return [ - TypeFailure(full_query, expected_column_types, rowValueType, expected, - result, type(result), cursor.description, conf) + TypeFailure( + full_query, + expected_column_types, + rowValueType, + expected, + result, + type(result), + cursor.description, + conf, + ) ] - if self.validate_row_value_type and type(result) \ - is not rowValueType: + if self.validate_row_value_type and type(result) is not rowValueType: return [ - TypeFailure(full_query, expected_column_types, rowValueType, expected, - result, type(result), cursor.description, conf) + TypeFailure( + full_query, + expected_column_types, + rowValueType, + expected, + result, + type(result), + cursor.description, + conf, + ) ] if self.validate_result and str(expected) != str(result): return [ - ResultFailure(full_query, expected_column_types, rowValueType, expected, - result, type(result), cursor.description, conf) + ResultFailure( + full_query, + expected_column_types, + rowValueType, + expected, + result, + type(result), + cursor.description, + conf, + ) ] return [] except Exception as e: return [ - ExecFailure(full_query, columnType, rowValueType, None, None, None, - cursor.description, conf, e) + ExecFailure( + full_query, + columnType, + rowValueType, + None, + None, + None, + cursor.description, + conf, + e, + ) ] diff --git a/tests/e2e/common/decimal_tests.py b/tests/e2e/common/decimal_tests.py index 8051d2a18..5005cdf11 100644 --- a/tests/e2e/common/decimal_tests.py +++ b/tests/e2e/common/decimal_tests.py @@ -1,6 +1,7 @@ from decimal import Decimal import pyarrow +import pytest class 
DecimalTestsMixin: @@ -9,7 +10,7 @@ class DecimalTestsMixin: ("1000000.0000 AS DECIMAL(11, 4)", Decimal("1000000.0000"), pyarrow.decimal128(11, 4)), ("-10.2343 AS DECIMAL(10, 6)", Decimal("-10.234300"), pyarrow.decimal128(10, 6)), # TODO(SC-90767): Re-enable this test after we have a way of passing `ansi_mode` = False - #("-13872347.2343 AS DECIMAL(10, 10)", None, pyarrow.decimal128(10, 10)), + # ("-13872347.2343 AS DECIMAL(10, 10)", None, pyarrow.decimal128(10, 10)), ("NULL AS DECIMAL(1, 1)", None, pyarrow.decimal128(1, 1)), ("1 AS DECIMAL(1, 0)", Decimal("1"), pyarrow.decimal128(1, 0)), ("0.00000 AS DECIMAL(5, 3)", Decimal("0.000"), pyarrow.decimal128(5, 3)), @@ -17,32 +18,36 @@ class DecimalTestsMixin: ] multi_decimals_and_expected_results = [ - (["1 AS DECIMAL(6, 3)", "100.001 AS DECIMAL(6, 3)", "NULL AS DECIMAL(6, 3)"], - [Decimal("1.00"), Decimal("100.001"), None], pyarrow.decimal128(6, 3)), - (["1 AS DECIMAL(6, 3)", "2 AS DECIMAL(5, 2)"], [Decimal('1.000'), - Decimal('2.000')], pyarrow.decimal128(6, - 3)), + ( + ["1 AS DECIMAL(6, 3)", "100.001 AS DECIMAL(6, 3)", "NULL AS DECIMAL(6, 3)"], + [Decimal("1.00"), Decimal("100.001"), None], + pyarrow.decimal128(6, 3), + ), + ( + ["1 AS DECIMAL(6, 3)", "2 AS DECIMAL(5, 2)"], + [Decimal("1.000"), Decimal("2.000")], + pyarrow.decimal128(6, 3), + ), ] - def test_decimals(self): + @pytest.mark.parametrize("decimal, expected_value, expected_type", decimal_and_expected_results) + def test_decimals(self, decimal, expected_value, expected_type): with self.cursor({}) as cursor: - for (decimal, expected_value, expected_type) in self.decimal_and_expected_results: - query = "SELECT CAST ({})".format(decimal) - with self.subTest(query=query): - cursor.execute(query) - table = cursor.fetchmany_arrow(1) - self.assertEqual(table.field(0).type, expected_type) - self.assertEqual(table.to_pydict().popitem()[1][0], expected_value) + query = "SELECT CAST ({})".format(decimal) + cursor.execute(query) + table = cursor.fetchmany_arrow(1) + assert table.field(0).type == expected_type + assert table.to_pydict().popitem()[1][0] == expected_value - def test_multi_decimals(self): + @pytest.mark.parametrize( + "decimals, expected_values, expected_type", multi_decimals_and_expected_results + ) + def test_multi_decimals(self, decimals, expected_values, expected_type): with self.cursor({}) as cursor: - for (decimals, expected_values, - expected_type) in self.multi_decimals_and_expected_results: - union_str = " UNION ".join(["(SELECT CAST ({}))".format(dec) for dec in decimals]) - query = "SELECT * FROM ({}) ORDER BY 1 NULLS LAST".format(union_str) + union_str = " UNION ".join(["(SELECT CAST ({}))".format(dec) for dec in decimals]) + query = "SELECT * FROM ({}) ORDER BY 1 NULLS LAST".format(union_str) - with self.subTest(query=query): - cursor.execute(query) - table = cursor.fetchall_arrow() - self.assertEqual(table.field(0).type, expected_type) - self.assertEqual(table.to_pydict().popitem()[1], expected_values) + cursor.execute(query) + table = cursor.fetchall_arrow() + assert table.field(0).type == expected_type + assert table.to_pydict().popitem()[1] == expected_values diff --git a/tests/e2e/common/large_queries_mixin.py b/tests/e2e/common/large_queries_mixin.py index 3e1e45bc4..8ec32fd4e 100644 --- a/tests/e2e/common/large_queries_mixin.py +++ b/tests/e2e/common/large_queries_mixin.py @@ -35,8 +35,10 @@ def fetch_rows(self, cursor, row_count, fetchmany_size): num_fetches = max(math.ceil(n / 10000), 1) latency_ms = int((time.time() - start_time) * 1000 / num_fetches), 1 
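# The decimal_tests.py hunks above trade the old for-loop + self.subTest pattern for
# @pytest.mark.parametrize, so every case is collected and reported as its own test.
# A minimal sketch of the same conversion, using a hypothetical square() under test:
import pytest

cases = [(1, 1), (2, 4), (3, 9)]

def square(n):
    return n * n

@pytest.mark.parametrize("value, expected", cases)
def test_square(value, expected):
    assert square(value) == expected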
- print('Fetched {} rows with an avg latency of {} per fetch, '.format(n, latency_ms) + - 'assuming 10K fetch size.') + print( + "Fetched {} rows with an avg latency of {} per fetch, ".format(n, latency_ms) + + "assuming 10K fetch size." + ) def test_query_with_large_wide_result_set(self): resultSize = 300 * 1000 * 1000 # 300 MB @@ -50,14 +52,15 @@ def test_query_with_large_wide_result_set(self): self.arraysize = 1000 with self.cursor() as cursor: for lz4_compression in [False, True]: - cursor.connection.lz4_compression=lz4_compression + cursor.connection.lz4_compression = lz4_compression uuids = ", ".join(["uuid() uuid{}".format(i) for i in range(cols)]) - cursor.execute("SELECT id, {uuids} FROM RANGE({rows})".format(uuids=uuids, rows=rows)) - self.assertEqual(lz4_compression, cursor.active_result_set.lz4_compressed) + cursor.execute( + "SELECT id, {uuids} FROM RANGE({rows})".format(uuids=uuids, rows=rows) + ) + assert lz4_compression == cursor.active_result_set.lz4_compressed for row_id, row in enumerate(self.fetch_rows(cursor, rows, fetchmany_size)): - self.assertEqual(row[0], row_id) # Verify no rows are dropped in the middle. - self.assertEqual(len(row[1]), 36) - + assert row[0] == row_id # Verify no rows are dropped in the middle. + assert len(row[1]) == 36 def test_query_with_large_narrow_result_set(self): resultSize = 300 * 1000 * 1000 # 300 MB @@ -71,10 +74,10 @@ def test_query_with_large_narrow_result_set(self): with self.cursor() as cursor: cursor.execute("SELECT * FROM RANGE({rows})".format(rows=rows)) for row_id, row in enumerate(self.fetch_rows(cursor, rows, fetchmany_size)): - self.assertEqual(row[0], row_id) + assert row[0] == row_id def test_long_running_query(self): - """ Incrementally increase query size until it takes at least 5 minutes, + """Incrementally increase query size until it takes at least 5 minutes, and asserts that the query completes successfully. 
""" minutes = 60 @@ -85,20 +88,24 @@ def test_long_running_query(self): scale_factor = 1 with self.cursor() as cursor: while duration < min_duration: - self.assertLess(scale_factor, 512, msg="Detected infinite loop") + assert scale_factor < 512, "Detected infinite loop" start = time.time() - cursor.execute("""SELECT count(*) + cursor.execute( + """SELECT count(*) FROM RANGE({scale}) x JOIN RANGE({scale0}) y ON from_unixtime(x.id * y.id, "yyyy-MM-dd") LIKE "%not%a%date%" - """.format(scale=scale_factor * scale0, scale0=scale0)) + """.format( + scale=scale_factor * scale0, scale0=scale0 + ) + ) - n, = cursor.fetchone() - self.assertEqual(n, 0) + (n,) = cursor.fetchone() + assert n == 0 duration = time.time() - start current_fraction = duration / min_duration - print('Took {} s with scale factor={}'.format(duration, scale_factor)) + print("Took {} s with scale factor={}".format(duration, scale_factor)) # Extrapolate linearly to reach 5 min and add 50% padding to push over the limit scale_factor = math.ceil(1.5 * scale_factor / current_fraction) diff --git a/tests/e2e/common/retry_test_mixins.py b/tests/e2e/common/retry_test_mixins.py index 5305c1243..96364ec6f 100644 --- a/tests/e2e/common/retry_test_mixins.py +++ b/tests/e2e/common/retry_test_mixins.py @@ -25,7 +25,7 @@ def test_client_should_retry_automatically_when_getting_429(self): self.assertEqual(rows[0][0], 1) def test_client_should_not_retry_429_if_RateLimitRetry_is_0(self): - with self.assertRaises(self.error_type) as cm: + with pytest.raises(self.error_type) as cm: with self.cursor(self.conf_to_disable_rate_limit_retries) as cursor: for _ in range(10): cursor.execute("SELECT 1") @@ -51,16 +51,14 @@ def test_wait_cluster_startup(self): cursor.fetchall() def _test_retry_disabled_with_message(self, error_msg_substring, exception_type): - with self.assertRaises(exception_type) as cm: + with pytest.raises(exception_type) as cm: with self.connection(self.conf_to_disable_temporarily_unavailable_retries): pass - self.assertIn(error_msg_substring, str(cm.exception)) + assert error_msg_substring in str(cm.exception) @contextmanager -def mocked_server_response( - status: int = 200, headers: dict = {}, redirect_location: str = None -): +def mocked_server_response(status: int = 200, headers: dict = {}, redirect_location: str = None): """Context manager for patching urllib3 responses""" # When mocking mocking a BaseHTTPResponse for urllib3 the mock must include @@ -98,9 +96,7 @@ def mock_sequential_server_responses(responses: List[dict]): # Each resp should have these members: for resp in responses: - _mock = MagicMock( - headers=resp["headers"], msg=resp["headers"], status=resp["status"] - ) + _mock = MagicMock(headers=resp["headers"], msg=resp["headers"], status=resp["status"]) _mock.get_redirect_location.return_value = ( False if resp["redirect_location"] is None else resp["redirect_location"] ) @@ -152,7 +148,7 @@ def test_oserror_retries(self): "urllib3.connectionpool.HTTPSConnectionPool._validate_conn", ) as mock_validate_conn: mock_validate_conn.side_effect = OSError("Some arbitrary network error") - with self.assertRaises(MaxRetryError) as cm: + with pytest.raises(MaxRetryError) as cm: with self.connection(extra_params=self._retry_policy) as conn: pass @@ -165,7 +161,7 @@ def test_retry_max_count_not_exceeded(self): before raising an exception """ with mocked_server_response(status=404) as mock_obj: - with self.assertRaises(MaxRetryError) as cm: + with pytest.raises(MaxRetryError) as cm: with 
self.connection(extra_params=self._retry_policy) as conn: pass assert mock_obj.return_value.getresponse.call_count == 6 @@ -176,10 +172,10 @@ def test_retry_max_duration_not_exceeded(self): THEN the connector raises a MaxRetryDurationError """ with mocked_server_response(status=429, headers={"Retry-After": "60"}): - with self.assertRaises(RequestError) as cm: + with pytest.raises(RequestError) as cm: with self.connection(extra_params=self._retry_policy) as conn: pass - assert isinstance(cm.exception.args[1], MaxRetryDurationError) + assert isinstance(cm.value.args[1], MaxRetryDurationError) def test_retry_abort_non_recoverable_error(self): """GIVEN the server returns a code 501 @@ -189,10 +185,10 @@ def test_retry_abort_non_recoverable_error(self): # Code 501 is a Not Implemented error with mocked_server_response(status=501): - with self.assertRaises(RequestError) as cm: + with pytest.raises(RequestError) as cm: with self.connection(extra_params=self._retry_policy) as conn: pass - assert isinstance(cm.exception.args[1], NonRecoverableNetworkError) + assert isinstance(cm.value.args[1], NonRecoverableNetworkError) def test_retry_abort_unsafe_execute_statement_retry_condition(self): """GIVEN the server sends a code other than 429 or 503 @@ -203,9 +199,9 @@ def test_retry_abort_unsafe_execute_statement_retry_condition(self): with conn.cursor() as cursor: # Code 502 is a Bad Gateway, which we commonly see in production under heavy load with mocked_server_response(status=502): - with self.assertRaises(RequestError) as cm: + with pytest.raises(RequestError) as cm: cursor.execute("Not a real query") - assert isinstance(cm.exception.args[1], UnsafeToRetryError) + assert isinstance(cm.value.args[1], UnsafeToRetryError) def test_retry_dangerous_codes(self): """GIVEN the server sends a dangerous code and the user forced this to be retryable @@ -227,14 +223,12 @@ def test_retry_dangerous_codes(self): with conn.cursor() as cursor: for dangerous_code in DANGEROUS_CODES: with mocked_server_response(status=dangerous_code): - with self.assertRaises(RequestError) as cm: + with pytest.raises(RequestError) as cm: cursor.execute("Not a real query") - assert isinstance(cm.exception.args[1], UnsafeToRetryError) + assert isinstance(cm.value.args[1], UnsafeToRetryError) # Prove that these codes are retried if forced by the user - with self.connection( - extra_params={**self._retry_policy, **additional_settings} - ) as conn: + with self.connection(extra_params={**self._retry_policy, **additional_settings}) as conn: with conn.cursor() as cursor: for dangerous_code in DANGEROUS_CODES: with mocked_server_response(status=dangerous_code): @@ -262,7 +256,7 @@ def test_retry_safe_execute_statement_retry_condition(self): cursor.execute("This query never reaches the server") assert mock_obj.return_value.getresponse.call_count == 2 - def test_retry_abort_close_session_on_404(self): + def test_retry_abort_close_session_on_404(self, caplog): """GIVEN the connector sends a CloseSession command WHEN server sends a 404 (which is normally retried) THEN nothing is retried because 404 means the session already closed @@ -277,22 +271,10 @@ def test_retry_abort_close_session_on_404(self): with self.connection(extra_params={**self._retry_policy}) as conn: with mock_sequential_server_responses(responses): - with self.assertLogs( - "databricks.sql", - level="INFO", - ) as cm: - conn.close() - expected_message_was_found = False - for log in cm.output: - if expected_message_was_found: - break - target = "Session was closed by a prior 
request" - expected_message_was_found = target in log - self.assertTrue( - expected_message_was_found, "Did not find expected log messages" - ) - - def test_retry_abort_close_operation_on_404(self): + conn.close() + assert "Session was closed by a prior request" in caplog.text + + def test_retry_abort_close_operation_on_404(self, caplog): """GIVEN the connector sends a CancelOperation command WHEN server sends a 404 (which is normally retried) THEN nothing is retried because 404 means the operation was already canceled @@ -315,20 +297,8 @@ def test_retry_abort_close_operation_on_404(self): # This call guarantees we have an open cursor at the server curs.execute("SELECT 1") with mock_sequential_server_responses(responses): - with self.assertLogs( - "databricks.sql", - level="INFO", - ) as cm: - curs.close() - expected_message_was_found = False - for log in cm.output: - if expected_message_was_found: - break - target = "Operation was canceled by a prior request" - expected_message_was_found = target in log - self.assertTrue( - expected_message_was_found, "Did not find expected log messages" - ) + curs.close() + assert "Operation was canceled by a prior request" in caplog.text def test_retry_max_redirects_raises_too_many_redirects_exception(self): """GIVEN the connector is configured with a custom max_redirects @@ -339,10 +309,8 @@ def test_retry_max_redirects_raises_too_many_redirects_exception(self): max_redirects, expected_call_count = 1, 2 # Code 302 is a redirect - with mocked_server_response( - status=302, redirect_location="/foo.bar" - ) as mock_obj: - with self.assertRaises(MaxRetryError) as cm: + with mocked_server_response(status=302, redirect_location="/foo.bar") as mock_obj: + with pytest.raises(MaxRetryError) as cm: with self.connection( extra_params={ **self._retry_policy, @@ -350,7 +318,7 @@ def test_retry_max_redirects_raises_too_many_redirects_exception(self): } ): pass - assert "too many redirects" == str(cm.exception.reason) + assert "too many redirects" == str(cm.value.reason) # Total call count should be 2 (original + 1 retry) assert mock_obj.return_value.getresponse.call_count == expected_call_count @@ -363,10 +331,8 @@ def test_retry_max_redirects_unset_doesnt_redirect_forever(self): _stop_after_attempts_count is enforced. """ # Code 302 is a redirect - with mocked_server_response( - status=302, redirect_location="/foo.bar/" - ) as mock_obj: - with self.assertRaises(MaxRetryError) as cm: + with mocked_server_response(status=302, redirect_location="/foo.bar/") as mock_obj: + with pytest.raises(MaxRetryError) as cm: with self.connection( extra_params={ **self._retry_policy, @@ -391,51 +357,25 @@ def test_retry_max_redirects_is_bounded_by_stop_after_attempts_count(self): with pytest.raises(RequestError) as cm: with mock_sequential_server_responses(responses): - with self.connection( - extra_params={**self._retry_policy, **additional_settings} - ): + with self.connection(extra_params={**self._retry_policy, **additional_settings}): pass # The error should be the result of the 500, not because of too many requests. 
assert "too many redirects" not in str(cm.value.message) assert "Error during request to server" in str(cm.value.message) - def test_retry_max_redirects_exceeds_max_attempts_count_warns_user(self): - with self.assertLogs( - "databricks.sql", - level="WARN", - ) as cm: - with self.connection( - extra_params={ - **self._retry_policy, - **{ - "_retry_max_redirects": 100, - "_retry_stop_after_attempts_count": 1, - }, - } - ): - pass - expected_message_was_found = False - for log in cm.output: - if expected_message_was_found: - break - target = "it will have no affect!" - expected_message_was_found = target in log - - assert expected_message_was_found, "Did not find expected log messages" - - def test_retry_legacy_behavior_warns_user(self): - with self.assertLogs( - "databricks.sql", - level="WARN", - ) as cm: - with self.connection( - extra_params={**self._retry_policy, "_enable_v3_retries": False} - ): - expected_message_was_found = False - for log in cm.output: - if expected_message_was_found: - break - target = "Legacy retry behavior is enabled for this connection." - expected_message_was_found = target in log - assert expected_message_was_found, "Did not find expected log messages" + def test_retry_max_redirects_exceeds_max_attempts_count_warns_user(self, caplog): + with self.connection( + extra_params={ + **self._retry_policy, + **{ + "_retry_max_redirects": 100, + "_retry_stop_after_attempts_count": 1, + }, + } + ): + assert "it will have no affect!" in caplog.text + + def test_retry_legacy_behavior_warns_user(self, caplog): + with self.connection(extra_params={**self._retry_policy, "_enable_v3_retries": False}): + assert "Legacy retry behavior is enabled for this connection." in caplog.text diff --git a/tests/e2e/common/staging_ingestion_tests.py b/tests/e2e/common/staging_ingestion_tests.py index 3cdeaff72..d8d0429f8 100644 --- a/tests/e2e/common/staging_ingestion_tests.py +++ b/tests/e2e/common/staging_ingestion_tests.py @@ -5,30 +5,27 @@ import databricks.sql as sql from databricks.sql import Error + @pytest.fixture(scope="module", autouse=True) -def check_staging_ingestion_user(): +def check_staging_ingestion_user(ingestion_user): """This fixture verifies that a staging ingestion user email address is present in the environment and raises an exception if not. The fixture only evaluates when the test _isn't skipped_. """ - staging_ingestion_user = os.getenv("staging_ingestion_user") - - if staging_ingestion_user is None: + if ingestion_user is None: raise ValueError( - "To run this test you must designate a `staging_ingestion_user` environment variable. This will be the user associated with the personal access token." + "To run this test you must designate a `DATABRICKS_USER` environment variable. This will be the user associated with the personal access token." ) + class PySQLStagingIngestionTestSuiteMixin: """Simple namespace for ingestion tests. 
These should be run against DBR >12.x In addition to connection credentials (host, path, token) this suite requires an env var named staging_ingestion_user""" - staging_ingestion_user = os.getenv("staging_ingestion_user") - - - def test_staging_ingestion_life_cycle(self): + def test_staging_ingestion_life_cycle(self, ingestion_user): """PUT a file into the staging location GET the file from the staging location REMOVE the file from the staging location @@ -47,7 +44,7 @@ def test_staging_ingestion_life_cycle(self): with self.connection(extra_params={"staging_allowed_local_path": temp_path}) as conn: cursor = conn.cursor() - query = f"PUT '{temp_path}' INTO 'stage://tmp/{self.staging_ingestion_user}/tmp/11/15/file1.csv' OVERWRITE" + query = f"PUT '{temp_path}' INTO 'stage://tmp/{ingestion_user}/tmp/11/15/file1.csv' OVERWRITE" cursor.execute(query) # GET should succeed @@ -56,7 +53,7 @@ def test_staging_ingestion_life_cycle(self): with self.connection(extra_params={"staging_allowed_local_path": new_temp_path}) as conn: cursor = conn.cursor() - query = f"GET 'stage://tmp/{self.staging_ingestion_user}/tmp/11/15/file1.csv' TO '{new_temp_path}'" + query = f"GET 'stage://tmp/{ingestion_user}/tmp/11/15/file1.csv' TO '{new_temp_path}'" cursor.execute(query) with open(new_fh, "rb") as fp: @@ -66,26 +63,25 @@ def test_staging_ingestion_life_cycle(self): # REMOVE should succeed - remove_query = ( - f"REMOVE 'stage://tmp/{self.staging_ingestion_user}/tmp/11/15/file1.csv'" - ) + remove_query = f"REMOVE 'stage://tmp/{ingestion_user}/tmp/11/15/file1.csv'" with self.connection(extra_params={"staging_allowed_local_path": "/"}) as conn: cursor = conn.cursor() cursor.execute(remove_query) - # GET after REMOVE should fail + # GET after REMOVE should fail with pytest.raises(Error, match="Staging operation over HTTP was unsuccessful: 404"): cursor = conn.cursor() - query = f"GET 'stage://tmp/{self.staging_ingestion_user}/tmp/11/15/file1.csv' TO '{new_temp_path}'" + query = ( + f"GET 'stage://tmp/{ingestion_user}/tmp/11/15/file1.csv' TO '{new_temp_path}'" + ) cursor.execute(query) os.remove(temp_path) os.remove(new_temp_path) - - def test_staging_ingestion_put_fails_without_staging_allowed_local_path(self): + def test_staging_ingestion_put_fails_without_staging_allowed_local_path(self, ingestion_user): """PUT operations are not supported unless the connection was built with a parameter called staging_allowed_local_path """ @@ -100,11 +96,12 @@ def test_staging_ingestion_put_fails_without_staging_allowed_local_path(self): with pytest.raises(Error, match="You must provide at least one staging_allowed_local_path"): with self.connection() as conn: cursor = conn.cursor() - query = f"PUT '{temp_path}' INTO 'stage://tmp/{self.staging_ingestion_user}/tmp/11/15/file1.csv' OVERWRITE" + query = f"PUT '{temp_path}' INTO 'stage://tmp/{ingestion_user}/tmp/11/15/file1.csv' OVERWRITE" cursor.execute(query) - def test_staging_ingestion_put_fails_if_localFile_not_in_staging_allowed_local_path(self): - + def test_staging_ingestion_put_fails_if_localFile_not_in_staging_allowed_local_path( + self, ingestion_user + ): fh, temp_path = tempfile.mkstemp() @@ -118,15 +115,17 @@ def test_staging_ingestion_put_fails_if_localFile_not_in_staging_allowed_local_p # Add junk to base_path base_path = os.path.join(base_path, "temp") - with pytest.raises(Error, match="Local file operations are restricted to paths within the configured staging_allowed_local_path"): + with pytest.raises( + Error, + match="Local file operations are restricted to paths 
within the configured staging_allowed_local_path", + ): with self.connection(extra_params={"staging_allowed_local_path": base_path}) as conn: cursor = conn.cursor() - query = f"PUT '{temp_path}' INTO 'stage://tmp/{self.staging_ingestion_user}/tmp/11/15/file1.csv' OVERWRITE" + query = f"PUT '{temp_path}' INTO 'stage://tmp/{ingestion_user}/tmp/11/15/file1.csv' OVERWRITE" cursor.execute(query) - def test_staging_ingestion_put_fails_if_file_exists_and_overwrite_not_set(self): - """PUT a file into the staging location twice. First command should succeed. Second should fail. - """ + def test_staging_ingestion_put_fails_if_file_exists_and_overwrite_not_set(self, ingestion_user): + """PUT a file into the staging location twice. First command should succeed. Second should fail.""" fh, temp_path = tempfile.mkstemp() @@ -138,18 +137,18 @@ def test_staging_ingestion_put_fails_if_file_exists_and_overwrite_not_set(self): def perform_put(): with self.connection(extra_params={"staging_allowed_local_path": temp_path}) as conn: cursor = conn.cursor() - query = f"PUT '{temp_path}' INTO 'stage://tmp/{self.staging_ingestion_user}/tmp/12/15/file1.csv'" + query = f"PUT '{temp_path}' INTO 'stage://tmp/{ingestion_user}/tmp/12/15/file1.csv'" cursor.execute(query) def perform_remove(): - remove_query = ( - f"REMOVE 'stage://tmp/{self.staging_ingestion_user}/tmp/12/15/file1.csv'" - ) - - with self.connection(extra_params={"staging_allowed_local_path": "/"}) as conn: - cursor = conn.cursor() - cursor.execute(remove_query) + try: + remove_query = f"REMOVE 'stage://tmp/{ingestion_user}/tmp/12/15/file1.csv'" + with self.connection(extra_params={"staging_allowed_local_path": "/"}) as conn: + cursor = conn.cursor() + cursor.execute(remove_query) + except Exception: + pass # Make sure file does not exist perform_remove() @@ -158,15 +157,16 @@ def perform_remove(): perform_put() # Try to put it again - with pytest.raises(sql.exc.ServerOperationError, match="FILE_IN_STAGING_PATH_ALREADY_EXISTS"): + with pytest.raises( + sql.exc.ServerOperationError, match="FILE_IN_STAGING_PATH_ALREADY_EXISTS" + ): perform_put() # Clean up after ourselves perform_remove() - + def test_staging_ingestion_fails_to_modify_another_staging_user(self): - """The server should only allow modification of the staging_ingestion_user's files - """ + """The server should only allow modification of the staging_ingestion_user's files""" some_other_user = "mary.poppins@databricks.com" @@ -184,9 +184,7 @@ def perform_put(): cursor.execute(query) def perform_remove(): - remove_query = ( - f"REMOVE 'stage://tmp/{some_other_user}/tmp/12/15/file1.csv'" - ) + remove_query = f"REMOVE 'stage://tmp/{some_other_user}/tmp/12/15/file1.csv'" with self.connection(extra_params={"staging_allowed_local_path": "/"}) as conn: cursor = conn.cursor() @@ -210,7 +208,9 @@ def perform_get(): with pytest.raises(sql.exc.ServerOperationError, match="PERMISSION_DENIED"): perform_get() - def test_staging_ingestion_put_fails_if_absolute_localFile_not_in_staging_allowed_local_path(self): + def test_staging_ingestion_put_fails_if_absolute_localFile_not_in_staging_allowed_local_path( + self, ingestion_user + ): """ This test confirms that staging_allowed_local_path and target_file are resolved into absolute paths. 
""" @@ -221,33 +221,44 @@ def test_staging_ingestion_put_fails_if_absolute_localFile_not_in_staging_allowe staging_allowed_local_path = "/var/www/html" target_file = "/var/www/html/../html1/not_allowed.html" - with pytest.raises(Error, match="Local file operations are restricted to paths within the configured staging_allowed_local_path"): - with self.connection(extra_params={"staging_allowed_local_path": staging_allowed_local_path}) as conn: + with pytest.raises( + Error, + match="Local file operations are restricted to paths within the configured staging_allowed_local_path", + ): + with self.connection( + extra_params={"staging_allowed_local_path": staging_allowed_local_path} + ) as conn: cursor = conn.cursor() - query = f"PUT '{target_file}' INTO 'stage://tmp/{self.staging_ingestion_user}/tmp/11/15/file1.csv' OVERWRITE" + query = f"PUT '{target_file}' INTO 'stage://tmp/{ingestion_user}/tmp/11/15/file1.csv' OVERWRITE" cursor.execute(query) - def test_staging_ingestion_empty_local_path_fails_to_parse_at_server(self): + def test_staging_ingestion_empty_local_path_fails_to_parse_at_server(self, ingestion_user): staging_allowed_local_path = "/var/www/html" target_file = "" with pytest.raises(Error, match="EMPTY_LOCAL_FILE_IN_STAGING_ACCESS_QUERY"): - with self.connection(extra_params={"staging_allowed_local_path": staging_allowed_local_path}) as conn: + with self.connection( + extra_params={"staging_allowed_local_path": staging_allowed_local_path} + ) as conn: cursor = conn.cursor() - query = f"PUT '{target_file}' INTO 'stage://tmp/{self.staging_ingestion_user}/tmp/11/15/file1.csv' OVERWRITE" + query = f"PUT '{target_file}' INTO 'stage://tmp/{ingestion_user}/tmp/11/15/file1.csv' OVERWRITE" cursor.execute(query) - def test_staging_ingestion_invalid_staging_path_fails_at_server(self): + def test_staging_ingestion_invalid_staging_path_fails_at_server(self, ingestion_user): staging_allowed_local_path = "/var/www/html" target_file = "index.html" with pytest.raises(Error, match="INVALID_STAGING_PATH_IN_STAGING_ACCESS_QUERY"): - with self.connection(extra_params={"staging_allowed_local_path": staging_allowed_local_path}) as conn: + with self.connection( + extra_params={"staging_allowed_local_path": staging_allowed_local_path} + ) as conn: cursor = conn.cursor() - query = f"PUT '{target_file}' INTO 'stageRANDOMSTRINGOFCHARACTERS://tmp/{self.staging_ingestion_user}/tmp/11/15/file1.csv' OVERWRITE" + query = f"PUT '{target_file}' INTO 'stageRANDOMSTRINGOFCHARACTERS://tmp/{ingestion_user}/tmp/11/15/file1.csv' OVERWRITE" cursor.execute(query) - def test_staging_ingestion_supports_multiple_staging_allowed_local_path_values(self): + def test_staging_ingestion_supports_multiple_staging_allowed_local_path_values( + self, ingestion_user + ): """staging_allowed_local_path may be either a path-like object or a list of path-like objects. This test confirms that two configured base paths: @@ -258,31 +269,36 @@ def test_staging_ingestion_supports_multiple_staging_allowed_local_path_values(s def generate_file_and_path_and_queries(): """ - 1. Makes a temp file with some contents. - 2. Write a query to PUT it into a staging location - 3. Write a query to REMOVE it from that location (for cleanup) + 1. Makes a temp file with some contents. + 2. Write a query to PUT it into a staging location + 3. 
Write a query to REMOVE it from that location (for cleanup) """ fh, temp_path = tempfile.mkstemp() with open(fh, "wb") as fp: original_text = "hello world!".encode("utf-8") fp.write(original_text) - put_query = f"PUT '{temp_path}' INTO 'stage://tmp/{self.staging_ingestion_user}/tmp/11/15/{id(temp_path)}.csv' OVERWRITE" - remove_query = f"REMOVE 'stage://tmp/{self.staging_ingestion_user}/tmp/11/15/{id(temp_path)}.csv'" + put_query = f"PUT '{temp_path}' INTO 'stage://tmp/{ingestion_user}/tmp/11/15/{id(temp_path)}.csv' OVERWRITE" + remove_query = f"REMOVE 'stage://tmp/{ingestion_user}/tmp/11/15/{id(temp_path)}.csv'" return fh, temp_path, put_query, remove_query fh1, temp_path1, put_query1, remove_query1 = generate_file_and_path_and_queries() fh2, temp_path2, put_query2, remove_query2 = generate_file_and_path_and_queries() fh3, temp_path3, put_query3, remove_query3 = generate_file_and_path_and_queries() - with self.connection(extra_params={"staging_allowed_local_path": [temp_path1, temp_path2]}) as conn: + with self.connection( + extra_params={"staging_allowed_local_path": [temp_path1, temp_path2]} + ) as conn: cursor = conn.cursor() cursor.execute(put_query1) cursor.execute(put_query2) - - with pytest.raises(Error, match="Local file operations are restricted to paths within the configured staging_allowed_local_path"): + + with pytest.raises( + Error, + match="Local file operations are restricted to paths within the configured staging_allowed_local_path", + ): cursor.execute(put_query3) # Then clean up the files we made cursor.execute(remove_query1) - cursor.execute(remove_query2) \ No newline at end of file + cursor.execute(remove_query2) diff --git a/tests/e2e/common/timestamp_tests.py b/tests/e2e/common/timestamp_tests.py index 38b14e9e8..f25aed7e7 100644 --- a/tests/e2e/common/timestamp_tests.py +++ b/tests/e2e/common/timestamp_tests.py @@ -1,29 +1,31 @@ import datetime +import pytest + from .predicates import compare_dbr_versions, is_thrift_v5_plus, pysql_has_version class TimestampTestsMixin: - timestamp_and_expected_results = [ - ('2021-09-30 11:27:35.123+04:00', datetime.datetime(2021, 9, 30, 7, 27, 35, 123000)), - ('2021-09-30 11:27:35+04:00', datetime.datetime(2021, 9, 30, 7, 27, 35)), - ('2021-09-30 11:27:35.123', datetime.datetime(2021, 9, 30, 11, 27, 35, 123000)), - ('2021-09-30 11:27:35', datetime.datetime(2021, 9, 30, 11, 27, 35)), - ('2021-09-30 11:27', datetime.datetime(2021, 9, 30, 11, 27)), - ('2021-09-30 11', datetime.datetime(2021, 9, 30, 11)), - ('2021-09-30', datetime.datetime(2021, 9, 30)), - ('2021-09', datetime.datetime(2021, 9, 1)), - ('2021', datetime.datetime(2021, 1, 1)), - ('9999-12-31T15:59:59', datetime.datetime(9999, 12, 31, 15, 59, 59)), - ('9999-99-31T15:59:59', None), + date_and_expected_results = [ + ("2021-09-30", datetime.date(2021, 9, 30)), + ("2021-09", datetime.date(2021, 9, 1)), + ("2021", datetime.date(2021, 1, 1)), + ("9999-12-31", datetime.date(9999, 12, 31)), + ("9999-99-31", None), ] - date_and_expected_results = [ - ('2021-09-30', datetime.date(2021, 9, 30)), - ('2021-09', datetime.date(2021, 9, 1)), - ('2021', datetime.date(2021, 1, 1)), - ('9999-12-31', datetime.date(9999, 12, 31)), - ('9999-99-31', None), + timestamp_and_expected_results = [ + ("2021-09-30 11:27:35.123+04:00", datetime.datetime(2021, 9, 30, 7, 27, 35, 123000)), + ("2021-09-30 11:27:35+04:00", datetime.datetime(2021, 9, 30, 7, 27, 35)), + ("2021-09-30 11:27:35.123", datetime.datetime(2021, 9, 30, 11, 27, 35, 123000)), + ("2021-09-30 11:27:35", datetime.datetime(2021, 9, 30, 
11, 27, 35)), + ("2021-09-30 11:27", datetime.datetime(2021, 9, 30, 11, 27)), + ("2021-09-30 11", datetime.datetime(2021, 9, 30, 11)), + ("2021-09-30", datetime.datetime(2021, 9, 30)), + ("2021-09", datetime.datetime(2021, 9, 1)), + ("2021", datetime.datetime(2021, 1, 1)), + ("9999-12-31T15:59:59", datetime.datetime(9999, 12, 31, 15, 59, 59)), + ("9999-99-31T15:59:59", None), ] def should_add_timezone(self): @@ -31,7 +33,7 @@ def should_add_timezone(self): def maybe_add_timezone_to_timestamp(self, ts): """If we're using DBR >= 10.2, then we expect back aware timestamps, so add timezone to `ts` - Otherwise we have naive timestamps, so no change is needed + Otherwise we have naive timestamps, so no change is needed """ if ts and self.should_add_timezone(): return ts.replace(tzinfo=datetime.timezone.utc) @@ -39,19 +41,21 @@ def maybe_add_timezone_to_timestamp(self, ts): return ts def assertTimestampsEqual(self, result, expected): - self.assertEqual(result, self.maybe_add_timezone_to_timestamp(expected)) + assert result == self.maybe_add_timezone_to_timestamp(expected) def multi_query(self, n_rows=10): row_sql = "SELECT " + ", ".join( - ["TIMESTAMP('{}')".format(ts) for (ts, _) in self.timestamp_and_expected_results]) + ["TIMESTAMP('{}')".format(ts) for (ts, _) in self.timestamp_and_expected_results] + ) query = " UNION ALL ".join([row_sql for _ in range(n_rows)]) - expected_matrix = [[dt for (_, dt) in self.timestamp_and_expected_results] - for _ in range(n_rows)] + expected_matrix = [ + [dt for (_, dt) in self.timestamp_and_expected_results] for _ in range(n_rows) + ] return query, expected_matrix def test_timestamps(self): with self.cursor({"session_configuration": {"ansi_mode": False}}) as cursor: - for (timestamp, expected) in self.timestamp_and_expected_results: + for timestamp, expected in self.timestamp_and_expected_results: cursor.execute("SELECT TIMESTAMP('{timestamp}')".format(timestamp=timestamp)) result = cursor.fetchone()[0] self.assertTimestampsEqual(result, expected) @@ -62,13 +66,14 @@ def test_multi_timestamps(self): cursor.execute(query) result = cursor.fetchall() # We list-ify the rows because PyHive will return a tuple for a row - self.assertEqual([list(r) for r in result], - [[self.maybe_add_timezone_to_timestamp(ts) for ts in r] - for r in expected]) + assert [list(r) for r in result] == [ + [self.maybe_add_timezone_to_timestamp(ts) for ts in r] for r in expected + ] - def test_dates(self): + @pytest.mark.parametrize("date, expected", date_and_expected_results) + def test_dates(self, date, expected): with self.cursor({"session_configuration": {"ansi_mode": False}}) as cursor: - for (date, expected) in self.date_and_expected_results: + for date, expected in self.date_and_expected_results: cursor.execute("SELECT DATE('{date}')".format(date=date)) result = cursor.fetchone()[0] - self.assertEqual(result, expected) + assert result == expected diff --git a/tests/e2e/common/uc_volume_tests.py b/tests/e2e/common/uc_volume_tests.py index 7493280e7..21e43036a 100644 --- a/tests/e2e/common/uc_volume_tests.py +++ b/tests/e2e/common/uc_volume_tests.py @@ -5,30 +5,26 @@ import databricks.sql as sql from databricks.sql import Error + @pytest.fixture(scope="module", autouse=True) -def check_catalog_and_schema(): +def check_catalog_and_schema(catalog, schema): """This fixture verifies that a catalog and schema are present in the environment. The fixture only evaluates when the test _isn't skipped_. 
""" - _catalog = os.getenv("catalog") - _schema = os.getenv("schema") - - if _catalog is None or _schema is None: + if catalog is None or schema is None: raise ValueError( f"UC Volume tests require values for the `catalog` and `schema` environment variables. Found catalog {_catalog} schema {_schema}" ) + class PySQLUCVolumeTestSuiteMixin: """Simple namespace for UC Volume tests. In addition to connection credentials (host, path, token) this suite requires env vars named catalog and schema""" - catalog, schema = os.getenv("catalog"), os.getenv("schema") - - - def test_uc_volume_life_cycle(self): + def test_uc_volume_life_cycle(self, catalog, schema): """PUT a file into the UC Volume GET the file from the UC Volume REMOVE the file from the UC Volume @@ -47,7 +43,9 @@ def test_uc_volume_life_cycle(self): with self.connection(extra_params={"staging_allowed_local_path": temp_path}) as conn: cursor = conn.cursor() - query = f"PUT '{temp_path}' INTO '/Volumes/{self.catalog}/{self.schema}/e2etests/file1.csv' OVERWRITE" + query = ( + f"PUT '{temp_path}' INTO '/Volumes/{catalog}/{schema}/e2etests/file1.csv' OVERWRITE" + ) cursor.execute(query) # GET should succeed @@ -56,7 +54,7 @@ def test_uc_volume_life_cycle(self): with self.connection(extra_params={"staging_allowed_local_path": new_temp_path}) as conn: cursor = conn.cursor() - query = f"GET '/Volumes/{self.catalog}/{self.schema}/e2etests/file1.csv' TO '{new_temp_path}'" + query = f"GET '/Volumes/{catalog}/{schema}/e2etests/file1.csv' TO '{new_temp_path}'" cursor.execute(query) with open(new_fh, "rb") as fp: @@ -66,26 +64,23 @@ def test_uc_volume_life_cycle(self): # REMOVE should succeed - remove_query = ( - f"REMOVE '/Volumes/{self.catalog}/{self.schema}/e2etests/file1.csv'" - ) + remove_query = f"REMOVE '/Volumes/{catalog}/{schema}/e2etests/file1.csv'" with self.connection(extra_params={"staging_allowed_local_path": "/"}) as conn: cursor = conn.cursor() cursor.execute(remove_query) - # GET after REMOVE should fail + # GET after REMOVE should fail with pytest.raises(Error, match="Staging operation over HTTP was unsuccessful: 404"): cursor = conn.cursor() - query = f"GET '/Volumes/{self.catalog}/{self.schema}/e2etests/file1.csv' TO '{new_temp_path}'" + query = f"GET '/Volumes/{catalog}/{schema}/e2etests/file1.csv' TO '{new_temp_path}'" cursor.execute(query) os.remove(temp_path) os.remove(new_temp_path) - - def test_uc_volume_put_fails_without_staging_allowed_local_path(self): + def test_uc_volume_put_fails_without_staging_allowed_local_path(self, catalog, schema): """PUT operations are not supported unless the connection was built with a parameter called staging_allowed_local_path """ @@ -100,11 +95,12 @@ def test_uc_volume_put_fails_without_staging_allowed_local_path(self): with pytest.raises(Error, match="You must provide at least one staging_allowed_local_path"): with self.connection() as conn: cursor = conn.cursor() - query = f"PUT '{temp_path}' INTO '/Volumes/{self.catalog}/{self.schema}/e2etests/file1.csv' OVERWRITE" + query = f"PUT '{temp_path}' INTO '/Volumes/{catalog}/{schema}/e2etests/file1.csv' OVERWRITE" cursor.execute(query) - def test_uc_volume_put_fails_if_localFile_not_in_staging_allowed_local_path(self): - + def test_uc_volume_put_fails_if_localFile_not_in_staging_allowed_local_path( + self, catalog, schema + ): fh, temp_path = tempfile.mkstemp() @@ -118,17 +114,18 @@ def test_uc_volume_put_fails_if_localFile_not_in_staging_allowed_local_path(self # Add junk to base_path base_path = os.path.join(base_path, "temp") - with 
pytest.raises(Error, match="Local file operations are restricted to paths within the configured staging_allowed_local_path"): + with pytest.raises( + Error, + match="Local file operations are restricted to paths within the configured staging_allowed_local_path", + ): with self.connection(extra_params={"staging_allowed_local_path": base_path}) as conn: cursor = conn.cursor() - query = f"PUT '{temp_path}' INTO '/Volumes/{self.catalog}/{self.schema}/e2etests/file1.csv' OVERWRITE" + query = f"PUT '{temp_path}' INTO '/Volumes/{catalog}/{schema}/e2etests/file1.csv' OVERWRITE" cursor.execute(query) - def test_uc_volume_put_fails_if_file_exists_and_overwrite_not_set(self): - """PUT a file into the staging location twice. First command should succeed. Second should fail. - """ + def test_uc_volume_put_fails_if_file_exists_and_overwrite_not_set(self, catalog, schema): + """PUT a file into the staging location twice. First command should succeed. Second should fail.""" - fh, temp_path = tempfile.mkstemp() original_text = "hello world!".encode("utf-8") @@ -139,18 +136,18 @@ def test_uc_volume_put_fails_if_file_exists_and_overwrite_not_set(self): def perform_put(): with self.connection(extra_params={"staging_allowed_local_path": temp_path}) as conn: cursor = conn.cursor() - query = f"PUT '{temp_path}' INTO '/Volumes/{self.catalog}/{self.schema}/e2etests/file1.csv'" + query = f"PUT '{temp_path}' INTO '/Volumes/{catalog}/{schema}/e2etests/file1.csv'" cursor.execute(query) def perform_remove(): - remove_query = ( - f"REMOVE '/Volumes/{self.catalog}/{self.schema}/e2etests/file1.csv'" - ) - - with self.connection(extra_params={"staging_allowed_local_path": "/"}) as conn: - cursor = conn.cursor() - cursor.execute(remove_query) + try: + remove_query = f"REMOVE '/Volumes/{catalog}/{schema}/e2etests/file1.csv'" + with self.connection(extra_params={"staging_allowed_local_path": "/"}) as conn: + cursor = conn.cursor() + cursor.execute(remove_query) + except Exception: + pass # Make sure file does not exist perform_remove() @@ -159,14 +156,17 @@ def perform_remove(): perform_put() # Try to put it again - with pytest.raises(sql.exc.ServerOperationError, match="FILE_IN_STAGING_PATH_ALREADY_EXISTS"): + with pytest.raises( + sql.exc.ServerOperationError, match="FILE_IN_STAGING_PATH_ALREADY_EXISTS" + ): perform_put() # Clean up after ourselves perform_remove() - - def test_uc_volume_put_fails_if_absolute_localFile_not_in_staging_allowed_local_path(self): + def test_uc_volume_put_fails_if_absolute_localFile_not_in_staging_allowed_local_path( + self, catalog, schema + ): """ This test confirms that staging_allowed_local_path and target_file are resolved into absolute paths. 
""" @@ -177,33 +177,42 @@ def test_uc_volume_put_fails_if_absolute_localFile_not_in_staging_allowed_local_ staging_allowed_local_path = "/var/www/html" target_file = "/var/www/html/../html1/not_allowed.html" - with pytest.raises(Error, match="Local file operations are restricted to paths within the configured staging_allowed_local_path"): - with self.connection(extra_params={"staging_allowed_local_path": staging_allowed_local_path}) as conn: + with pytest.raises( + Error, + match="Local file operations are restricted to paths within the configured staging_allowed_local_path", + ): + with self.connection( + extra_params={"staging_allowed_local_path": staging_allowed_local_path} + ) as conn: cursor = conn.cursor() - query = f"PUT '{target_file}' INTO '/Volumes/{self.catalog}/{self.schema}/e2etests/file1.csv' OVERWRITE" + query = f"PUT '{target_file}' INTO '/Volumes/{catalog}/{schema}/e2etests/file1.csv' OVERWRITE" cursor.execute(query) - def test_uc_volume_empty_local_path_fails_to_parse_at_server(self): + def test_uc_volume_empty_local_path_fails_to_parse_at_server(self, catalog, schema): staging_allowed_local_path = "/var/www/html" target_file = "" with pytest.raises(Error, match="EMPTY_LOCAL_FILE_IN_STAGING_ACCESS_QUERY"): - with self.connection(extra_params={"staging_allowed_local_path": staging_allowed_local_path}) as conn: + with self.connection( + extra_params={"staging_allowed_local_path": staging_allowed_local_path} + ) as conn: cursor = conn.cursor() - query = f"PUT '{target_file}' INTO '/Volumes/{self.catalog}/{self.schema}/e2etests/file1.csv' OVERWRITE" + query = f"PUT '{target_file}' INTO '/Volumes/{catalog}/{schema}/e2etests/file1.csv' OVERWRITE" cursor.execute(query) - def test_uc_volume_invalid_volume_path_fails_at_server(self): + def test_uc_volume_invalid_volume_path_fails_at_server(self, catalog, schema): staging_allowed_local_path = "/var/www/html" target_file = "index.html" - with pytest.raises(Error, match="NOT_FOUND: CATALOG"): - with self.connection(extra_params={"staging_allowed_local_path": staging_allowed_local_path}) as conn: + with pytest.raises(Error, match="NOT_FOUND: Catalog"): + with self.connection( + extra_params={"staging_allowed_local_path": staging_allowed_local_path} + ) as conn: cursor = conn.cursor() - query = f"PUT '{target_file}' INTO '/Volumes/RANDOMSTRINGOFCHARACTERS/{self.catalog}/{self.schema}/e2etests/file1.csv' OVERWRITE" + query = f"PUT '{target_file}' INTO '/Volumes/RANDOMSTRINGOFCHARACTERS/{catalog}/{schema}/e2etests/file1.csv' OVERWRITE" cursor.execute(query) - def test_uc_volume_supports_multiple_staging_allowed_local_path_values(self): + def test_uc_volume_supports_multiple_staging_allowed_local_path_values(self, catalog, schema): """staging_allowed_local_path may be either a path-like object or a list of path-like objects. This test confirms that two configured base paths: @@ -214,31 +223,36 @@ def test_uc_volume_supports_multiple_staging_allowed_local_path_values(self): def generate_file_and_path_and_queries(): """ - 1. Makes a temp file with some contents. - 2. Write a query to PUT it into a staging location - 3. Write a query to REMOVE it from that location (for cleanup) + 1. Makes a temp file with some contents. + 2. Write a query to PUT it into a staging location + 3. 
Write a query to REMOVE it from that location (for cleanup) """ fh, temp_path = tempfile.mkstemp() with open(fh, "wb") as fp: original_text = "hello world!".encode("utf-8") fp.write(original_text) - put_query = f"PUT '{temp_path}' INTO '/Volumes/{self.catalog}/{self.schema}/e2etests/{id(temp_path)}.csv' OVERWRITE" - remove_query = f"REMOVE '/Volumes/{self.catalog}/{self.schema}/e2etests/{id(temp_path)}.csv'" + put_query = f"PUT '{temp_path}' INTO '/Volumes/{catalog}/{schema}/e2etests/{id(temp_path)}.csv' OVERWRITE" + remove_query = f"REMOVE '/Volumes/{catalog}/{schema}/e2etests/{id(temp_path)}.csv'" return fh, temp_path, put_query, remove_query fh1, temp_path1, put_query1, remove_query1 = generate_file_and_path_and_queries() fh2, temp_path2, put_query2, remove_query2 = generate_file_and_path_and_queries() fh3, temp_path3, put_query3, remove_query3 = generate_file_and_path_and_queries() - with self.connection(extra_params={"staging_allowed_local_path": [temp_path1, temp_path2]}) as conn: + with self.connection( + extra_params={"staging_allowed_local_path": [temp_path1, temp_path2]} + ) as conn: cursor = conn.cursor() cursor.execute(put_query1) cursor.execute(put_query2) - - with pytest.raises(Error, match="Local file operations are restricted to paths within the configured staging_allowed_local_path"): + + with pytest.raises( + Error, + match="Local file operations are restricted to paths within the configured staging_allowed_local_path", + ): cursor.execute(put_query3) # Then clean up the files we made cursor.execute(remove_query1) - cursor.execute(remove_query2) \ No newline at end of file + cursor.execute(remove_query2) diff --git a/tests/e2e/test_complex_types.py b/tests/e2e/test_complex_types.py index d27e88220..0a7f514a8 100644 --- a/tests/e2e/test_complex_types.py +++ b/tests/e2e/test_complex_types.py @@ -1,4 +1,3 @@ - import pytest from numpy import ndarray @@ -7,7 +6,8 @@ class TestComplexTypes(PySQLPytestTestCase): @pytest.fixture(scope="class") - def table_fixture(self): + def table_fixture(self, connection_details): + self.arguments = connection_details.copy() """A pytest fixture that creates a table with a complex type, inserts a record, yields, and then drops the table""" with self.cursor() as cursor: @@ -53,9 +53,7 @@ def test_read_complex_types_as_arrow(self, field, expected_type, table_fixture): @pytest.mark.parametrize("field", [("array_col"), ("map_col"), ("struct_col")]) def test_read_complex_types_as_string(self, field, table_fixture): """Confirms the return type of a complex type that is returned as a string""" - with self.cursor( - extra_params={"_use_arrow_native_complex_types": False} - ) as cursor: + with self.cursor(extra_params={"_use_arrow_native_complex_types": False}) as cursor: result = cursor.execute( "SELECT * FROM pysql_test_complex_types_table LIMIT 1" ).fetchone() diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index d823a12d7..c23e4f79d 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -19,8 +19,23 @@ from urllib3.connectionpool import ReadTimeoutError import databricks.sql as sql -from databricks.sql import STRING, BINARY, NUMBER, DATETIME, DATE, DatabaseError, Error, OperationalError, RequestError -from tests.e2e.common.predicates import pysql_has_version, pysql_supports_arrow, compare_dbr_versions, is_thrift_v5_plus +from databricks.sql import ( + STRING, + BINARY, + NUMBER, + DATETIME, + DATE, + DatabaseError, + Error, + OperationalError, + RequestError, +) +from tests.e2e.common.predicates import ( + 
pysql_has_version, + pysql_supports_arrow, + compare_dbr_versions, + is_thrift_v5_plus, +) from tests.e2e.common.core_tests import CoreTestMixin, SmokeTestMixin from tests.e2e.common.large_queries_mixin import LargeQueriesMixin from tests.e2e.common.timestamp_tests import TimestampTestsMixin @@ -40,99 +55,44 @@ unsafe_logger.addHandler(logging.FileHandler("./tests-unsafe.log")) # manually decorate DecimalTestsMixin to need arrow support -for name in loader.getTestCaseNames(DecimalTestsMixin, 'test_'): +for name in loader.getTestCaseNames(DecimalTestsMixin, "test_"): fn = getattr(DecimalTestsMixin, name) - decorated = skipUnless(pysql_supports_arrow(), 'Decimal tests need arrow support')(fn) + decorated = skipUnless(pysql_supports_arrow(), "Decimal tests need arrow support")(fn) setattr(DecimalTestsMixin, name, decorated) -get_args_from_env = True - - -class PySQLTestCase(TestCase): - error_type = Error - conf_to_disable_rate_limit_retries = {"_retry_stop_after_attempts_count": 1} - conf_to_disable_temporarily_unavailable_retries = {"_retry_stop_after_attempts_count": 1} - - def __init__(self, method_name): - super().__init__(method_name) - # If running in local mode, just use environment variables for params. - self.arguments = os.environ if get_args_from_env else {} - self.arraysize = 1000 - self.buffer_size_bytes = 104857600 - - def connection_params(self, arguments): - params = { - "server_hostname": arguments["host"], - "http_path": arguments["http_path"], - **self.auth_params(arguments) - } - - return params - - def auth_params(self, arguments): - return { - "_username": arguments.get("rest_username"), - "_password": arguments.get("rest_password"), - "access_token": arguments.get("access_token") - } - - @contextmanager - def connection(self, extra_params=()): - connection_params = dict(self.connection_params(self.arguments), **dict(extra_params)) - - log.info("Connecting with args: {}".format(connection_params)) - conn = sql.connect(**connection_params) - - try: - yield conn - finally: - conn.close() - - @contextmanager - def cursor(self, extra_params=()): - with self.connection(extra_params) as conn: - cursor = conn.cursor(arraysize=self.arraysize, buffer_size_bytes=self.buffer_size_bytes) - try: - yield cursor - finally: - cursor.close() - - def assertEqualRowValues(self, actual, expected): - self.assertEqual(len(actual) if actual else 0, len(expected) if expected else 0) - for act, exp in zip(actual, expected): - self.assertSequenceEqual(act, exp) - -class PySQLPytestTestCase(): +class PySQLPytestTestCase: """A mirror of PySQLTest case that doesn't inherit from unittest.TestCase so that we can use pytest.mark.parameterize """ + error_type = Error conf_to_disable_rate_limit_retries = {"_retry_stop_after_attempts_count": 1} conf_to_disable_temporarily_unavailable_retries = {"_retry_stop_after_attempts_count": 1} - arguments = os.environ if get_args_from_env else {} arraysize = 1000 buffer_size_bytes = 104857600 - def connection_params(self, arguments): + @pytest.fixture(autouse=True) + def get_details(self, connection_details): + self.arguments = connection_details.copy() + + def connection_params(self): params = { - "server_hostname": arguments["host"], - "http_path": arguments["http_path"], - **self.auth_params(arguments) + "server_hostname": self.arguments["host"], + "http_path": self.arguments["http_path"], + **self.auth_params(), } return params - def auth_params(self, arguments): + def auth_params(self): return { - "_username": arguments.get("rest_username"), - "_password": 
arguments.get("rest_password"), - "access_token": arguments.get("access_token") + "access_token": self.arguments.get("access_token"), } @contextmanager def connection(self, extra_params=()): - connection_params = dict(self.connection_params(self.arguments), **dict(extra_params)) + connection_params = dict(self.connection_params(), **dict(extra_params)) log.info("Connecting with args: {}".format(connection_params)) conn = sql.connect(**connection_params) @@ -152,13 +112,16 @@ def cursor(self, extra_params=()): cursor.close() def assertEqualRowValues(self, actual, expected): - self.assertEqual(len(actual) if actual else 0, len(expected) if expected else 0) + len_actual = len(actual) if actual else 0 + len_expected = len(expected) if expected else 0 + assert len_actual == len_expected for act, exp in zip(actual, expected): - self.assertSequenceEqual(act, exp) + assert len(act) == len(exp) + for i in range(len(act)): + assert act[i] == exp[i] - -class PySQLLargeQueriesSuite(PySQLTestCase, LargeQueriesMixin): +class TestPySQLLargeQueriesSuite(PySQLPytestTestCase, LargeQueriesMixin): def get_some_rows(self, cursor, fetchmany_size): row = cursor.fetchone() if row: @@ -166,7 +129,8 @@ def get_some_rows(self, cursor, fetchmany_size): else: return None - @skipUnless(pysql_supports_arrow(), 'needs arrow support') + @skipUnless(pysql_supports_arrow(), "needs arrow support") + @pytest.mark.skip("This test requires a previously uploaded data set") def test_cloud_fetch(self): # This test can take several minutes to run limits = [100000, 300000] @@ -176,20 +140,24 @@ def test_cloud_fetch(self): # e2-dogfood host > hive_metastore catalog > main schema has such a table called store_sales. # If this table is deleted or this test is run on a different host, a different table may need to be used. 
base_query = "SELECT * FROM store_sales WHERE ss_sold_date_sk = 2452234 " - for num_limit, num_threads, lz4_compression in itertools.product(limits, threads, [True, False]): - with self.subTest(num_limit=num_limit, num_threads=num_threads, lz4_compression=lz4_compression): + for num_limit, num_threads, lz4_compression in itertools.product( + limits, threads, [True, False] + ): + with self.subTest( + num_limit=num_limit, num_threads=num_threads, lz4_compression=lz4_compression + ): cf_result, noop_result = None, None query = base_query + "LIMIT " + str(num_limit) - with self.cursor({ - "use_cloud_fetch": True, - "max_download_threads": num_threads, - "catalog": "hive_metastore" - }) as cursor: + with self.cursor( + { + "use_cloud_fetch": True, + "max_download_threads": num_threads, + "catalog": "hive_metastore", + }, + ) as cursor: cursor.execute(query) cf_result = cursor.fetchall() - with self.cursor({ - "catalog": "hive_metastore" - }) as cursor: + with self.cursor({"catalog": "hive_metastore"}) as cursor: cursor.execute(query) noop_result = cursor.fetchall() assert len(cf_result) == len(noop_result) @@ -199,8 +167,16 @@ def test_cloud_fetch(self): # Exclude Retry tests because they require specific setups, and LargeQueries too slow for core # tests -class PySQLCoreTestSuite(SmokeTestMixin, CoreTestMixin, DecimalTestsMixin, TimestampTestsMixin, - PySQLTestCase, PySQLStagingIngestionTestSuiteMixin, PySQLRetryTestsMixin, PySQLUCVolumeTestSuiteMixin): +class TestPySQLCoreSuite( + SmokeTestMixin, + CoreTestMixin, + DecimalTestsMixin, + TimestampTestsMixin, + PySQLPytestTestCase, + PySQLStagingIngestionTestSuiteMixin, + PySQLRetryTestsMixin, + PySQLUCVolumeTestSuiteMixin, +): validate_row_value_type = True validate_result = True @@ -209,21 +185,21 @@ class PySQLCoreTestSuite(SmokeTestMixin, CoreTestMixin, DecimalTestsMixin, Times # - also potentially a PEP-249 object like NUMBER, DATETIME etc. 
def expected_column_types(self, type_): type_mappings = { - 'boolean': ['boolean', NUMBER], - 'byte': ['tinyint', NUMBER], - 'short': ['smallint', NUMBER], - 'integer': ['int', NUMBER], - 'long': ['bigint', NUMBER], - 'decimal': ['decimal', NUMBER], - 'timestamp': ['timestamp', DATETIME], - 'date': ['date', DATE], - 'binary': ['binary', BINARY], - 'string': ['string', STRING], - 'array': ['array'], - 'struct': ['struct'], - 'map': ['map'], - 'double': ['double', NUMBER], - 'null': ['null'] + "boolean": ["boolean", NUMBER], + "byte": ["tinyint", NUMBER], + "short": ["smallint", NUMBER], + "integer": ["int", NUMBER], + "long": ["bigint", NUMBER], + "decimal": ["decimal", NUMBER], + "timestamp": ["timestamp", DATETIME], + "date": ["date", DATE], + "binary": ["binary", BINARY], + "string": ["string", STRING], + "array": ["array"], + "struct": ["struct"], + "map": ["map"], + "double": ["double", NUMBER], + "null": ["null"], } return type_mappings[type_] @@ -232,7 +208,7 @@ def test_queries(self): array_type = str array_val = "[1,2,3]" struct_type = str - struct_val = "{\"a\":1,\"b\":2}" + struct_val = '{"a":1,"b":2}' map_type = str map_val = "{1:2,3:4}" else: @@ -246,52 +222,56 @@ def test_queries(self): null_type = "null" if float(sql.__version__[0:2]) < 2.0 else "string" self.range_queries = CoreTestMixin.range_queries + [ ("NULL", null_type, type(None), None), - ("array(1, 2, 3)", 'array', array_type, array_val), - ("struct(1 as a, 2 as b)", 'struct', struct_type, struct_val), - ("map(1, 2, 3, 4)", 'map', map_type, map_val), + ("array(1, 2, 3)", "array", array_type, array_val), + ("struct(1 as a, 2 as b)", "struct", struct_type, struct_val), + ("map(1, 2, 3, 4)", "map", map_type, map_val), ] self.run_tests_on_queries({}) - @skipIf(pysql_has_version('<', '2'), 'requires pysql v2') + @skipIf(pysql_has_version("<", "2"), "requires pysql v2") def test_incorrect_query_throws_exception(self): with self.cursor({}) as cursor: # Syntax errors should contain the invalid SQL - with self.assertRaises(DatabaseError) as cm: + with pytest.raises(DatabaseError) as cm: cursor.execute("^ FOO BAR") - self.assertIn("FOO BAR", str(cm.exception)) + assert "FOO BAR" in str(cm.value) # Database error should contain the missing database - with self.assertRaises(DatabaseError) as cm: + with pytest.raises(DatabaseError) as cm: cursor.execute("USE foo234823498ydfsiusdhf") - self.assertIn("foo234823498ydfsiusdhf", str(cm.exception)) + assert "foo234823498ydfsiusdhf" in str(cm.value) # SQL with Extraneous input should send back the extraneous input - with self.assertRaises(DatabaseError) as cm: + with pytest.raises(DatabaseError) as cm: cursor.execute("CREATE TABLE IF NOT EXISTS TABLE table_234234234") - self.assertIn("table_234234234", str(cm.exception)) + assert "table_234234234" in str(cm.value) def test_create_table_will_return_empty_result_set(self): with self.cursor({}) as cursor: - table_name = 'table_{uuid}'.format(uuid=str(uuid4()).replace('-', '_')) + table_name = "table_{uuid}".format(uuid=str(uuid4()).replace("-", "_")) try: cursor.execute( "CREATE TABLE IF NOT EXISTS {} AS (SELECT 1 AS col_1, '2' AS col_2)".format( - table_name)) - self.assertEqual(cursor.fetchall(), []) + table_name + ) + ) + assert cursor.fetchall() == [] finally: cursor.execute("DROP TABLE IF EXISTS {}".format(table_name)) def test_get_tables(self): with self.cursor({}) as cursor: - table_name = 'table_{uuid}'.format(uuid=str(uuid4()).replace('-', '_')) - table_names = [table_name + '_1', table_name + '_2'] + table_name = 
"table_{uuid}".format(uuid=str(uuid4()).replace("-", "_")) + table_names = [table_name + "_1", table_name + "_2"] try: for table in table_names: cursor.execute( "CREATE TABLE IF NOT EXISTS {} AS (SELECT 1 AS col_1, '2' AS col_2)".format( - table)) + table + ) + ) cursor.tables(schema_name="defa%") tables = cursor.fetchall() tables_desc = cursor.description @@ -299,38 +279,42 @@ def test_get_tables(self): for table in table_names: # Test only schema name and table name. # From other columns, what is supported depends on DBR version. - self.assertIn(['default', table], [list(table[1:3]) for table in tables]) - self.assertEqual( - tables_desc, - [('TABLE_CAT', 'string', None, None, None, None, None), - ('TABLE_SCHEM', 'string', None, None, None, None, None), - ('TABLE_NAME', 'string', None, None, None, None, None), - ('TABLE_TYPE', 'string', None, None, None, None, None), - ('REMARKS', 'string', None, None, None, None, None), - ('TYPE_CAT', 'string', None, None, None, None, None), - ('TYPE_SCHEM', 'string', None, None, None, None, None), - ('TYPE_NAME', 'string', None, None, None, None, None), - ('SELF_REFERENCING_COL_NAME', 'string', None, None, None, None, None), - ('REF_GENERATION', 'string', None, None, None, None, None)]) + assert ["default", table] in [list(table[1:3]) for table in tables] + expected = [ + ("TABLE_CAT", "string", None, None, None, None, None), + ("TABLE_SCHEM", "string", None, None, None, None, None), + ("TABLE_NAME", "string", None, None, None, None, None), + ("TABLE_TYPE", "string", None, None, None, None, None), + ("REMARKS", "string", None, None, None, None, None), + ("TYPE_CAT", "string", None, None, None, None, None), + ("TYPE_SCHEM", "string", None, None, None, None, None), + ("TYPE_NAME", "string", None, None, None, None, None), + ("SELF_REFERENCING_COL_NAME", "string", None, None, None, None, None), + ("REF_GENERATION", "string", None, None, None, None, None), + ] + assert tables_desc == expected + finally: for table in table_names: - cursor.execute('DROP TABLE IF EXISTS {}'.format(table)) + cursor.execute("DROP TABLE IF EXISTS {}".format(table)) def test_get_columns(self): with self.cursor({}) as cursor: - table_name = 'table_{uuid}'.format(uuid=str(uuid4()).replace('-', '_')) - table_names = [table_name + '_1', table_name + '_2'] + table_name = "table_{uuid}".format(uuid=str(uuid4()).replace("-", "_")) + table_names = [table_name + "_1", table_name + "_2"] try: for table in table_names: - cursor.execute("CREATE TABLE IF NOT EXISTS {} AS (SELECT " - "1 AS col_1, " - "'2' AS col_2, " - "named_struct('name', 'alice', 'age', 28) as col_3, " - "map('items', 45, 'cost', 228) as col_4, " - "array('item1', 'item2', 'item3') as col_5)".format(table)) - - cursor.columns(schema_name="defa%", table_name=table_name + '%') + cursor.execute( + "CREATE TABLE IF NOT EXISTS {} AS (SELECT " + "1 AS col_1, " + "'2' AS col_2, " + "named_struct('name', 'alice', 'age', 28) as col_3, " + "map('items', 45, 'cost', 228) as col_4, " + "array('item1', 'item2', 'item3') as col_5)".format(table) + ) + + cursor.columns(schema_name="defa%", table_name=table_name + "%") cols = cursor.fetchall() cols_desc = cursor.description @@ -340,117 +324,140 @@ def test_get_columns(self): for col in cleaned_response: col[4] = col[4].replace("`", "") - self.assertEqual(cleaned_response, [ - ['default', table_name + '_1', 'col_1', 4, 'INT'], - ['default', table_name + '_1', 'col_2', 12, 'STRING'], - ['default', table_name + '_1', 'col_3', 2002, 'STRUCT'], - ['default', table_name + '_1', 'col_4', 2000, 
'MAP'], - ['default', table_name + '_1', 'col_5', 2003, 'ARRAY'], - ['default', table_name + '_2', 'col_1', 4, 'INT'], - ['default', table_name + '_2', 'col_2', 12, 'STRING'], - ['default', table_name + '_2', 'col_3', 2002, 'STRUCT'], - ['default', table_name + '_2', 'col_4', 2000, 'MAP'], + expected = [ + ["default", table_name + "_1", "col_1", 4, "INT"], + ["default", table_name + "_1", "col_2", 12, "STRING"], + [ + "default", + table_name + "_1", + "col_3", + 2002, + "STRUCT", + ], + ["default", table_name + "_1", "col_4", 2000, "MAP"], + ["default", table_name + "_1", "col_5", 2003, "ARRAY"], + ["default", table_name + "_2", "col_1", 4, "INT"], + ["default", table_name + "_2", "col_2", 12, "STRING"], + [ + "default", + table_name + "_2", + "col_3", + 2002, + "STRUCT", + ], + ["default", table_name + "_2", "col_4", 2000, "MAP"], [ - 'default', - table_name + '_2', - 'col_5', + "default", + table_name + "_2", + "col_5", 2003, - 'ARRAY', - ] - ]) - - self.assertEqual(cols_desc, - [('TABLE_CAT', 'string', None, None, None, None, None), - ('TABLE_SCHEM', 'string', None, None, None, None, None), - ('TABLE_NAME', 'string', None, None, None, None, None), - ('COLUMN_NAME', 'string', None, None, None, None, None), - ('DATA_TYPE', 'int', None, None, None, None, None), - ('TYPE_NAME', 'string', None, None, None, None, None), - ('COLUMN_SIZE', 'int', None, None, None, None, None), - ('BUFFER_LENGTH', 'tinyint', None, None, None, None, None), - ('DECIMAL_DIGITS', 'int', None, None, None, None, None), - ('NUM_PREC_RADIX', 'int', None, None, None, None, None), - ('NULLABLE', 'int', None, None, None, None, None), - ('REMARKS', 'string', None, None, None, None, None), - ('COLUMN_DEF', 'string', None, None, None, None, None), - ('SQL_DATA_TYPE', 'int', None, None, None, None, None), - ('SQL_DATETIME_SUB', 'int', None, None, None, None, None), - ('CHAR_OCTET_LENGTH', 'int', None, None, None, None, None), - ('ORDINAL_POSITION', 'int', None, None, None, None, None), - ('IS_NULLABLE', 'string', None, None, None, None, None), - ('SCOPE_CATALOG', 'string', None, None, None, None, None), - ('SCOPE_SCHEMA', 'string', None, None, None, None, None), - ('SCOPE_TABLE', 'string', None, None, None, None, None), - ('SOURCE_DATA_TYPE', 'smallint', None, None, None, None, None), - ('IS_AUTO_INCREMENT', 'string', None, None, None, None, None)]) + "ARRAY", + ], + ] + assert cleaned_response == expected + expected = [ + ("TABLE_CAT", "string", None, None, None, None, None), + ("TABLE_SCHEM", "string", None, None, None, None, None), + ("TABLE_NAME", "string", None, None, None, None, None), + ("COLUMN_NAME", "string", None, None, None, None, None), + ("DATA_TYPE", "int", None, None, None, None, None), + ("TYPE_NAME", "string", None, None, None, None, None), + ("COLUMN_SIZE", "int", None, None, None, None, None), + ("BUFFER_LENGTH", "tinyint", None, None, None, None, None), + ("DECIMAL_DIGITS", "int", None, None, None, None, None), + ("NUM_PREC_RADIX", "int", None, None, None, None, None), + ("NULLABLE", "int", None, None, None, None, None), + ("REMARKS", "string", None, None, None, None, None), + ("COLUMN_DEF", "string", None, None, None, None, None), + ("SQL_DATA_TYPE", "int", None, None, None, None, None), + ("SQL_DATETIME_SUB", "int", None, None, None, None, None), + ("CHAR_OCTET_LENGTH", "int", None, None, None, None, None), + ("ORDINAL_POSITION", "int", None, None, None, None, None), + ("IS_NULLABLE", "string", None, None, None, None, None), + ("SCOPE_CATALOG", "string", None, None, None, None, None), + ("SCOPE_SCHEMA", 
"string", None, None, None, None, None), + ("SCOPE_TABLE", "string", None, None, None, None, None), + ("SOURCE_DATA_TYPE", "smallint", None, None, None, None, None), + ("IS_AUTO_INCREMENT", "string", None, None, None, None, None), + ] + assert cols_desc == expected finally: for table in table_names: - cursor.execute('DROP TABLE IF EXISTS {}'.format(table)) + cursor.execute("DROP TABLE IF EXISTS {}".format(table)) def test_escape_single_quotes(self): with self.cursor({}) as cursor: - table_name = 'table_{uuid}'.format(uuid=str(uuid4()).replace('-', '_')) + table_name = "table_{uuid}".format(uuid=str(uuid4()).replace("-", "_")) # Test escape syntax directly - cursor.execute("CREATE TABLE IF NOT EXISTS {} AS (SELECT 'you\\'re' AS col_1)".format(table_name)) + cursor.execute( + "CREATE TABLE IF NOT EXISTS {} AS (SELECT 'you\\'re' AS col_1)".format(table_name) + ) cursor.execute("SELECT * FROM {} WHERE col_1 LIKE 'you\\'re'".format(table_name)) rows = cursor.fetchall() assert rows[0]["col_1"] == "you're" # Test escape syntax in parameter - cursor.execute("SELECT * FROM {} WHERE {}.col_1 LIKE %(var)s".format(table_name, table_name), parameters={"var": "you're"}) + cursor.execute( + "SELECT * FROM {} WHERE {}.col_1 LIKE %(var)s".format(table_name, table_name), + parameters={"var": "you're"}, + ) rows = cursor.fetchall() assert rows[0]["col_1"] == "you're" def test_get_schemas(self): with self.cursor({}) as cursor: - database_name = 'db_{uuid}'.format(uuid=str(uuid4()).replace('-', '_')) + database_name = "db_{uuid}".format(uuid=str(uuid4()).replace("-", "_")) try: - cursor.execute('CREATE DATABASE IF NOT EXISTS {}'.format(database_name)) + cursor.execute("CREATE DATABASE IF NOT EXISTS {}".format(database_name)) cursor.schemas() schemas = cursor.fetchall() schemas_desc = cursor.description # Catalogue name not consistent across DBR versions, so we skip that - self.assertIn(database_name, [schema[0] for schema in schemas]) - self.assertEqual(schemas_desc, - [('TABLE_SCHEM', 'string', None, None, None, None, None), - ('TABLE_CATALOG', 'string', None, None, None, None, None)]) + assert database_name in [schema[0] for schema in schemas] + assert schemas_desc == [ + ("TABLE_SCHEM", "string", None, None, None, None, None), + ("TABLE_CATALOG", "string", None, None, None, None, None), + ] + finally: - cursor.execute('DROP DATABASE IF EXISTS {}'.format(database_name)) + cursor.execute("DROP DATABASE IF EXISTS {}".format(database_name)) def test_get_catalogs(self): with self.cursor({}) as cursor: cursor.catalogs() cursor.fetchall() catalogs_desc = cursor.description - self.assertEqual(catalogs_desc, [('TABLE_CAT', 'string', None, None, None, None, None)]) + assert catalogs_desc == [("TABLE_CAT", "string", None, None, None, None, None)] - @skipUnless(pysql_supports_arrow(), 'arrow test need arrow support') + @skipUnless(pysql_supports_arrow(), "arrow test need arrow support") def test_get_arrow(self): # These tests are quite light weight as the arrow fetch methods are used internally # by everything else with self.cursor({}) as cursor: cursor.execute("SELECT * FROM range(10)") table_1 = cursor.fetchmany_arrow(1).to_pydict() - self.assertEqual(table_1, OrderedDict([("id", [0])])) + assert table_1 == OrderedDict([("id", [0])]) table_2 = cursor.fetchall_arrow().to_pydict() - self.assertEqual(table_2, OrderedDict([("id", [1, 2, 3, 4, 5, 6, 7, 8, 9])])) + assert table_2 == OrderedDict([("id", [1, 2, 3, 4, 5, 6, 7, 8, 9])]) def test_unicode(self): unicode_str = "数据砖" with self.cursor({}) as cursor: 
cursor.execute("SELECT '{}'".format(unicode_str)) results = cursor.fetchall() - self.assertTrue(len(results) == 1 and len(results[0]) == 1) - self.assertEqual(results[0][0], unicode_str) + assert len(results) == 1 and len(results[0]) == 1 + assert results[0][0] == unicode_str def test_cancel_during_execute(self): with self.cursor({}) as cursor: def execute_really_long_query(): - cursor.execute("SELECT SUM(A.id - B.id) " + - "FROM range(1000000000) A CROSS JOIN range(100000000) B " + - "GROUP BY (A.id - B.id)") + cursor.execute( + "SELECT SUM(A.id - B.id) " + + "FROM range(1000000000) A CROSS JOIN range(100000000) B " + + "GROUP BY (A.id - B.id)" + ) exec_thread = threading.Thread(target=execute_really_long_query) @@ -459,24 +466,24 @@ def execute_really_long_query(): time.sleep(15) cursor.cancel() exec_thread.join(5) - self.assertFalse(exec_thread.is_alive()) + assert not exec_thread.is_alive() # Fetching results should throw an exception - with self.assertRaises((Error, thrift.Thrift.TException)): + with pytest.raises((Error, thrift.Thrift.TException)): cursor.fetchall() - with self.assertRaises((Error, thrift.Thrift.TException)): + with pytest.raises((Error, thrift.Thrift.TException)): cursor.fetchone() - with self.assertRaises((Error, thrift.Thrift.TException)): + with pytest.raises((Error, thrift.Thrift.TException)): cursor.fetchmany(10) # We should be able to execute a new command on the cursor cursor.execute("SELECT * FROM range(3)") - self.assertEqual(len(cursor.fetchall()), 3) + assert len(cursor.fetchall()) == 3 - @skipIf(pysql_has_version('<', '2'), 'requires pysql v2') + @skipIf(pysql_has_version("<", "2"), "requires pysql v2") def test_can_execute_command_after_failure(self): with self.cursor({}) as cursor: - with self.assertRaises(DatabaseError): + with pytest.raises(DatabaseError): cursor.execute("this is a sytnax error") cursor.execute("SELECT 1;") @@ -484,7 +491,7 @@ def test_can_execute_command_after_failure(self): res = cursor.fetchall() self.assertEqualRowValues(res, [[1]]) - @skipIf(pysql_has_version('<', '2'), 'requires pysql v2') + @skipIf(pysql_has_version("<", "2"), "requires pysql v2") def test_can_execute_command_after_success(self): with self.cursor({}) as cursor: cursor.execute("SELECT 1;") @@ -497,19 +504,19 @@ def generate_multi_row_query(self): query = "SELECT * FROM range(3);" return query - @skipIf(pysql_has_version('<', '2'), 'requires pysql v2') + @skipIf(pysql_has_version("<", "2"), "requires pysql v2") def test_fetchone(self): with self.cursor({}) as cursor: query = self.generate_multi_row_query() cursor.execute(query) - self.assertSequenceEqual(cursor.fetchone(), [0]) - self.assertSequenceEqual(cursor.fetchone(), [1]) - self.assertSequenceEqual(cursor.fetchone(), [2]) + assert cursor.fetchone()[0] == 0 + assert cursor.fetchone()[0] == 1 + assert cursor.fetchone()[0] == 2 - self.assertEqual(cursor.fetchone(), None) + assert cursor.fetchone() == None - @skipIf(pysql_has_version('<', '2'), 'requires pysql v2') + @skipIf(pysql_has_version("<", "2"), "requires pysql v2") def test_fetchall(self): with self.cursor({}) as cursor: query = self.generate_multi_row_query() @@ -517,9 +524,9 @@ def test_fetchall(self): self.assertEqualRowValues(cursor.fetchall(), [[0], [1], [2]]) - self.assertEqual(cursor.fetchone(), None) + assert cursor.fetchone() == None - @skipIf(pysql_has_version('<', '2'), 'requires pysql v2') + @skipIf(pysql_has_version("<", "2"), "requires pysql v2") def test_fetchmany_when_stride_fits(self): with self.cursor({}) as cursor: query = "SELECT * 
FROM range(4)" @@ -528,7 +535,7 @@ def test_fetchmany_when_stride_fits(self): self.assertEqualRowValues(cursor.fetchmany(2), [[0], [1]]) self.assertEqualRowValues(cursor.fetchmany(2), [[2], [3]]) - @skipIf(pysql_has_version('<', '2'), 'requires pysql v2') + @skipIf(pysql_has_version("<", "2"), "requires pysql v2") def test_fetchmany_in_excess(self): with self.cursor({}) as cursor: query = "SELECT * FROM range(4)" @@ -537,15 +544,16 @@ def test_fetchmany_in_excess(self): self.assertEqualRowValues(cursor.fetchmany(3), [[0], [1], [2]]) self.assertEqualRowValues(cursor.fetchmany(3), [[3]]) - @skipIf(pysql_has_version('<', '2'), 'requires pysql v2') + @skipIf(pysql_has_version("<", "2"), "requires pysql v2") def test_iterator_api(self): with self.cursor({}) as cursor: query = "SELECT * FROM range(4)" cursor.execute(query) expected_results = [[0], [1], [2], [3]] - for (i, row) in enumerate(cursor): - self.assertSequenceEqual(row, expected_results[i]) + for i, row in enumerate(cursor): + for j in range(len(row)): + assert row[j] == expected_results[i][j] def test_temp_view_fetch(self): with self.cursor({}) as cursor: @@ -554,68 +562,72 @@ def test_temp_view_fetch(self): # TODO assert on a result # once what is being returned has stabilised - @skipIf(pysql_has_version('<', '2'), 'requires pysql v2') - @skipIf(True, "Unclear the purpose of this test since urllib3 does not complain when timeout == 0") + @skipIf(pysql_has_version("<", "2"), "requires pysql v2") + @skipIf( + True, "Unclear the purpose of this test since urllib3 does not complain when timeout == 0" + ) def test_socket_timeout(self): # We expect to see a BlockingIO error when the socket is opened # in non-blocking mode, since no poll is done before the read - with self.assertRaises(OperationalError) as cm: + with pytest.raises(OperationalError) as cm: with self.cursor({"_socket_timeout": 0}): pass self.assertIsInstance(cm.exception.args[1], io.BlockingIOError) - @skipIf(pysql_has_version('<', '2'), 'requires pysql v2') + @skipIf(pysql_has_version("<", "2"), "requires pysql v2") + @skipIf(pysql_has_version(">", "2.8"), "This test has been broken for a while") def test_socket_timeout_user_defined(self): # We expect to see a TimeoutError when the socket timeout is only # 1 sec for a query that takes longer than that to process - with self.assertRaises(ReadTimeoutError) as cm: + with pytest.raises(ReadTimeoutError) as cm: with self.cursor({"_socket_timeout": 1}) as cursor: query = "select * from range(1000000000)" cursor.execute(query) - def test_ssp_passthrough(self): for enable_ansi in (True, False): with self.cursor({"session_configuration": {"ansi_mode": enable_ansi}}) as cursor: cursor.execute("SET ansi_mode") - self.assertEqual(list(cursor.fetchone()), ["ansi_mode", str(enable_ansi)]) + assert list(cursor.fetchone()) == ["ansi_mode", str(enable_ansi)] - @skipUnless(pysql_supports_arrow(), 'arrow test needs arrow support') + @skipUnless(pysql_supports_arrow(), "arrow test needs arrow support") def test_timestamps_arrow(self): with self.cursor({"session_configuration": {"ansi_mode": False}}) as cursor: - for (timestamp, expected) in self.timestamp_and_expected_results: + for timestamp, expected in self.timestamp_and_expected_results: cursor.execute("SELECT TIMESTAMP('{timestamp}')".format(timestamp=timestamp)) arrow_table = cursor.fetchmany_arrow(1) if self.should_add_timezone(): ts_type = pyarrow.timestamp("us", tz="Etc/UTC") else: ts_type = pyarrow.timestamp("us") - self.assertEqual(arrow_table.field(0).type, ts_type) + assert 
arrow_table.field(0).type == ts_type result_value = arrow_table.column(0).combine_chunks()[0].value # To work consistently across different local timezones, we specify the timezone # of the expected result to # be UTC (what it should be by default on the server) aware_timestamp = expected and expected.replace(tzinfo=datetime.timezone.utc) - self.assertEqual(result_value, aware_timestamp and - aware_timestamp.timestamp() * 1000000, - "timestamp {} did not match {}".format(timestamp, expected)) + assert result_value == ( + aware_timestamp and aware_timestamp.timestamp() * 1000000 + ), "timestamp {} did not match {}".format(timestamp, expected) - @skipUnless(pysql_supports_arrow(), 'arrow test needs arrow support') + @skipUnless(pysql_supports_arrow(), "arrow test needs arrow support") def test_multi_timestamps_arrow(self): with self.cursor({"session_configuration": {"ansi_mode": False}}) as cursor: query, expected = self.multi_query() - expected = [[self.maybe_add_timezone_to_timestamp(ts) for ts in row] - for row in expected] + expected = [ + [self.maybe_add_timezone_to_timestamp(ts) for ts in row] for row in expected + ] cursor.execute(query) table = cursor.fetchall_arrow() # Transpose columnar result to list of rows list_of_cols = [c.to_pylist() for c in table] - result = [[col[row_index] for col in list_of_cols] - for row_index in range(table.num_rows)] - self.assertEqual(result, expected) + result = [ + [col[row_index] for col in list_of_cols] for row_index in range(table.num_rows) + ] + assert result == expected - @skipUnless(pysql_supports_arrow(), 'arrow test needs arrow support') + @skipUnless(pysql_supports_arrow(), "arrow test needs arrow support") def test_timezone_with_timestamp(self): if self.should_add_timezone(): with self.cursor() as cursor: @@ -624,34 +636,34 @@ def test_timezone_with_timestamp(self): amsterdam = pytz.timezone("Europe/Amsterdam") expected = amsterdam.localize(datetime.datetime(2022, 3, 2, 12, 54, 56)) result = cursor.fetchone()[0] - self.assertEqual(result, expected) + assert result == expected cursor.execute("select CAST('2022-03-02 12:54:56' as TIMESTAMP)") arrow_result_table = cursor.fetchmany_arrow(1) arrow_result_value = arrow_result_table.column(0).combine_chunks()[0].value ts_type = pyarrow.timestamp("us", tz="Europe/Amsterdam") - self.assertEqual(arrow_result_table.field(0).type, ts_type) - self.assertEqual(arrow_result_value, expected.timestamp() * 1000000) + assert arrow_result_table.field(0).type == ts_type + assert arrow_result_value == expected.timestamp() * 1000000 - @skipUnless(pysql_supports_arrow(), 'arrow test needs arrow support') + @skipUnless(pysql_supports_arrow(), "arrow test needs arrow support") def test_can_flip_compression(self): with self.cursor() as cursor: cursor.execute("SELECT array(1,2,3,4)") cursor.fetchall() lz4_compressed = cursor.active_result_set.lz4_compressed - #The endpoint should support compression - self.assertEqual(lz4_compressed, True) - cursor.connection.lz4_compression=False + # The endpoint should support compression + assert lz4_compressed + cursor.connection.lz4_compression = False cursor.execute("SELECT array(1,2,3,4)") cursor.fetchall() lz4_compressed = cursor.active_result_set.lz4_compressed - self.assertEqual(lz4_compressed, False) + assert not lz4_compressed def _should_have_native_complex_types(self): return pysql_has_version(">=", 2) and is_thrift_v5_plus(self.arguments) - @skipUnless(pysql_supports_arrow(), 'arrow test needs arrow support') + @skipUnless(pysql_supports_arrow(), "arrow test needs 
arrow support") def test_arrays_are_not_returned_as_strings_arrow(self): if self._should_have_native_complex_types(): with self.cursor() as cursor: @@ -659,10 +671,10 @@ def test_arrays_are_not_returned_as_strings_arrow(self): arrow_df = cursor.fetchall_arrow() list_type = arrow_df.field(0).type - self.assertTrue(pyarrow.types.is_list(list_type)) - self.assertTrue(pyarrow.types.is_integer(list_type.value_type)) + assert pyarrow.types.is_list(list_type) + assert pyarrow.types.is_integer(list_type.value_type) - @skipUnless(pysql_supports_arrow(), 'arrow test needs arrow support') + @skipUnless(pysql_supports_arrow(), "arrow test needs arrow support") def test_structs_are_not_returned_as_strings_arrow(self): if self._should_have_native_complex_types(): with self.cursor() as cursor: @@ -670,9 +682,9 @@ def test_structs_are_not_returned_as_strings_arrow(self): arrow_df = cursor.fetchall_arrow() struct_type = arrow_df.field(0).type - self.assertTrue(pyarrow.types.is_struct(struct_type)) + assert pyarrow.types.is_struct(struct_type) - @skipUnless(pysql_supports_arrow(), 'arrow test needs arrow support') + @skipUnless(pysql_supports_arrow(), "arrow test needs arrow support") def test_decimal_not_returned_as_strings_arrow(self): if self._should_have_native_complex_types(): with self.cursor() as cursor: @@ -680,7 +692,7 @@ def test_decimal_not_returned_as_strings_arrow(self): arrow_df = cursor.fetchall_arrow() decimal_type = arrow_df.field(0).type - self.assertTrue(pyarrow.types.is_decimal(decimal_type)) + assert pyarrow.types.is_decimal(decimal_type) def test_close_connection_closes_cursors(self): @@ -688,82 +700,60 @@ def test_close_connection_closes_cursors(self): with self.connection() as conn: cursor = conn.cursor() - cursor.execute('SELECT id, id `id2`, id `id3` FROM RANGE(1000000) order by RANDOM()') + cursor.execute("SELECT id, id `id2`, id `id3` FROM RANGE(1000000) order by RANDOM()") ars = cursor.active_result_set # We must manually run this check because thrift_backend always forces `has_been_closed_server_side` to True # Cursor op state should be open before connection is closed - status_request = ttypes.TGetOperationStatusReq(operationHandle=ars.command_id, getProgressUpdate=False) + status_request = ttypes.TGetOperationStatusReq( + operationHandle=ars.command_id, getProgressUpdate=False + ) op_status_at_server = ars.thrift_backend._client.GetOperationStatus(status_request) assert op_status_at_server.operationState != ttypes.TOperationState.CLOSED_STATE conn.close() - + # When connection closes, any cursor operations should no longer exist at the server - with self.assertRaises(SessionAlreadyClosedError) as cm: + with pytest.raises(SessionAlreadyClosedError) as cm: op_status_at_server = ars.thrift_backend._client.GetOperationStatus(status_request) + def test_closing_a_closed_connection_doesnt_fail(self, caplog): + caplog.set_level(logging.DEBUG) + # Second .close() call is when this context manager exits + with self.connection() as conn: + # First .close() call is explicit here + conn.close() - - def test_closing_a_closed_connection_doesnt_fail(self): - - with self.assertLogs("databricks.sql", level="DEBUG",) as cm: - # Second .close() call is when this context manager exits - with self.connection() as conn: - # First .close() call is explicit here - conn.close() - - expected_message_was_found = False - for log in cm.output: - if expected_message_was_found: - break - target = "Session appears to have been closed already" - expected_message_was_found = target in log - - 
self.assertTrue(expected_message_was_found, "Did not find expected log messages") + assert "Session appears to have been closed already" in caplog.text # use a RetrySuite to encapsulate these tests which we'll typically want to run together; however keep # the 429/503 subsuites separate since they execute under different circumstances. -class PySQLRetryTestSuite: - class HTTP429Suite(Client429ResponseMixin, PySQLTestCase): +class TestPySQLRetrySuite: + class HTTP429Suite(Client429ResponseMixin, PySQLPytestTestCase): pass # Mixin covers all - class HTTP503Suite(Client503ResponseMixin, PySQLTestCase): + class HTTP503Suite(Client503ResponseMixin, PySQLPytestTestCase): # 503Response suite gets custom error here vs PyODBC def test_retry_disabled(self): self._test_retry_disabled_with_message("TEMPORARILY_UNAVAILABLE", OperationalError) -class PySQLUnityCatalogTestSuite(PySQLTestCase): +class TestPySQLUnityCatalogSuite(PySQLPytestTestCase): """Simple namespace tests that should be run against a unity-catalog-enabled cluster""" - @skipIf(pysql_has_version('<', '2'), 'requires pysql v2') + @skipIf(pysql_has_version("<", "2"), "requires pysql v2") def test_initial_namespace(self): - table_name = 'table_{uuid}'.format(uuid=str(uuid4()).replace('-', '_')) + table_name = "table_{uuid}".format(uuid=str(uuid4()).replace("-", "_")) with self.cursor() as cursor: - cursor.execute("USE CATALOG {}".format(self.arguments["catA"])) + cursor.execute("USE CATALOG {}".format(self.arguments["catalog"])) cursor.execute("CREATE TABLE table_{} (col1 int)".format(table_name)) - with self.connection({ - "catalog": self.arguments["catA"], - "schema": table_name - }) as connection: + with self.connection( + {"catalog": self.arguments["catalog"], "schema": table_name} + ) as connection: cursor = connection.cursor() cursor.execute("select current_catalog()") - self.assertEqual(cursor.fetchone()[0], self.arguments["catA"]) + assert cursor.fetchone()[0] == self.arguments["catalog"] cursor.execute("select current_database()") - self.assertEqual(cursor.fetchone()[0], table_name) - - - -def main(cli_args): - global get_args_from_env - get_args_from_env = True - print(f"Running tests with version: {sql.__version__}") - logging.getLogger("databricks.sql").setLevel(logging.INFO) - unittest.main(module=__file__, argv=sys.argv[0:1] + cli_args) - - -if __name__ == "__main__": - main(sys.argv[1:]) + assert cursor.fetchone()[0] == table_name diff --git a/tests/e2e/test_parameterized_queries.py b/tests/e2e/test_parameterized_queries.py index 97901e9ce..47dfc38cc 100644 --- a/tests/e2e/test_parameterized_queries.py +++ b/tests/e2e/test_parameterized_queries.py @@ -51,8 +51,10 @@ class Primitive(Enum): FLOAT = 3.15 SMALLINT = 51 + class PrimitiveExtra(Enum): """These are not inferrable types. This Enum is used for parametrized tests.""" + TIMESTAMP_NTZ = datetime.datetime(2023, 9, 6, 3, 14, 27, 843) TINYINT = 20 @@ -108,7 +110,8 @@ def _get_inline_table_column(self, value): return self.inline_type_map[Primitive(value)] @pytest.fixture(scope="class") - def inline_table(self): + def inline_table(self, connection_details): + self.arguments = connection_details.copy() """This table is necessary to verify that a parameter sent with INLINE approach can actually write to its analogous data type. @@ -164,8 +167,12 @@ def _inline_roundtrip(self, params: dict, paramstyle: ParamStyle): This is a no-op but is included to make the test-code easier to read. 
""" target_column = self._get_inline_table_column(params.get("p")) - INSERT_QUERY = f"INSERT INTO pysql_e2e_inline_param_test_table (`{target_column}`) VALUES (%(p)s)" - SELECT_QUERY = f"SELECT {target_column} `col` FROM pysql_e2e_inline_param_test_table LIMIT 1" + INSERT_QUERY = ( + f"INSERT INTO pysql_e2e_inline_param_test_table (`{target_column}`) VALUES (%(p)s)" + ) + SELECT_QUERY = ( + f"SELECT {target_column} `col` FROM pysql_e2e_inline_param_test_table LIMIT 1" + ) DELETE_QUERY = "DELETE FROM pysql_e2e_inline_param_test_table" with self.connection(extra_params={"use_inline_params": True}) as conn: @@ -274,7 +281,7 @@ def test_primitive_single( (Primitive.FLOAT, FloatParameter), (Primitive.SMALLINT, SmallIntParameter), (PrimitiveExtra.TIMESTAMP_NTZ, TimestampNTZParameter), - (PrimitiveExtra.TINYINT, TinyIntParameter) + (PrimitiveExtra.TINYINT, TinyIntParameter), ], ) def test_dbsqlparameter_single( @@ -301,15 +308,11 @@ def test_use_inline_off_by_default_with_warning(self, use_inline_params, caplog) If a user explicitly sets use_inline_params, don't warn them about it. """ - extra_args = ( - {"use_inline_params": use_inline_params} if use_inline_params else {} - ) + extra_args = {"use_inline_params": use_inline_params} if use_inline_params else {} with self.connection(extra_params=extra_args) as conn: with conn.cursor() as cursor: - with self.patch_server_supports_native_params( - supports_native_params=True - ): + with self.patch_server_supports_native_params(supports_native_params=True): cursor.execute("SELECT %(p)s", parameters={"p": 1}) if use_inline_params is True: assert ( @@ -351,8 +354,10 @@ def test_positional_native_multiple(self, params): def test_readme_example(self): with self.cursor() as cursor: - result = cursor.execute('SELECT :param `p`, * FROM RANGE(10)', {"param": "foo"}).fetchall() - + result = cursor.execute( + "SELECT :param `p`, * FROM RANGE(10)", {"param": "foo"} + ).fetchall() + assert len(result) == 10 assert result[0].p == "foo" @@ -397,9 +402,7 @@ def test_inline_ordinals_can_break_sql(self): query = "SELECT 'samsonite', %s WHERE 'samsonite' LIKE '%sonite'" params = ["luggage"] with self.cursor(extra_params={"use_inline_params": True}) as cursor: - with pytest.raises( - TypeError, match="not enough arguments for format string" - ): + with pytest.raises(TypeError, match="not enough arguments for format string"): cursor.execute(query, parameters=params) def test_inline_named_dont_break_sql(self): @@ -433,18 +436,18 @@ def test_inline_like_wildcard_breaks(self): a SQL LIKE wildcard %. This test proves that's the case. """ query = "SELECT 1 `col` WHERE 'foo' LIKE '%'" - params ={"param": 'bar'} + params = {"param": "bar"} with self.cursor(extra_params={"use_inline_params": True}) as cursor: with pytest.raises(ValueError, match="unsupported format character"): result = cursor.execute(query, parameters=params).fetchone() def test_native_like_wildcard_works(self): - """This is a mirror of test_inline_like_wildcard_breaks that proves that LIKE + """This is a mirror of test_inline_like_wildcard_breaks that proves that LIKE wildcards work under the native approach. """ query = "SELECT 1 `col` WHERE 'foo' LIKE '%'" - params ={"param": 'bar'} + params = {"param": "bar"} with self.cursor(extra_params={"use_inline_params": False}) as cursor: result = cursor.execute(query, parameters=params).fetchone() - + assert result.col == 1