diff --git a/aurora_data_api/__init__.py b/aurora_data_api/__init__.py index 948acd2..ce47ae6 100644 --- a/aurora_data_api/__init__.py +++ b/aurora_data_api/__init__.py @@ -12,7 +12,7 @@ Time, Timestamp, DateFromTicks, - # TimeFromTicks, + TimeFromTicks, TimestampFromTicks, Binary, STRING, @@ -22,3 +22,17 @@ ROWID, DECIMAL, ) +from .exceptions import ( + Warning, + Error, + InterfaceError, + DatabaseError, + DataError, + OperationalError, + IntegrityError, + InternalError, + ProgrammingError, + NotSupportedError, + MySQLError, + PostgreSQLError, +) diff --git a/aurora_data_api/async_.py b/aurora_data_api/async_.py index d681a4e..b1db2db 100644 --- a/aurora_data_api/async_.py +++ b/aurora_data_api/async_.py @@ -1,29 +1,41 @@ """ aurora-data-api - A Python DB-API 2.0 client for the AWS Aurora Serverless Data API (Async version) """ -import os, datetime, ipaddress, uuid, time, random, string, logging, itertools, reprlib, json, re, asyncio -from decimal import Decimal -from collections import namedtuple -from collections.abc import Mapping + +import time +import random +import string +import reprlib from .base import BaseAuroraDataAPIClient, BaseAuroraDataAPICursor, logger from .base import ( - apilevel, threadsafety, paramstyle, Date, Time, Timestamp, DateFromTicks, - TimestampFromTicks, Binary, STRING, BINARY, NUMBER, DATETIME, ROWID, DECIMAL, - ColumnDescription + apilevel, # noqa: F401 + threadsafety, # noqa: F401 + paramstyle, # noqa: F401 + Date, # noqa: F401 + Time, # noqa: F401 + Timestamp, # noqa: F401 + DateFromTicks, # noqa: F401 + TimeFromTicks, # noqa: F401 + TimestampFromTicks, # noqa: F401 + Binary, # noqa: F401 + STRING, # noqa: F401 + BINARY, # noqa: F401 + NUMBER, # noqa: F401 + DATETIME, # noqa: F401 + ROWID, # noqa: F401 + DECIMAL, # noqa: F401 ) from .exceptions import ( - Warning, - Error, - InterfaceError, - DatabaseError, - DataError, - OperationalError, - IntegrityError, - InternalError, - ProgrammingError, - NotSupportedError, - MySQLError, - 
PostgreSQLError, + Warning, # noqa: F401 + Error, # noqa: F401 + InterfaceError, # noqa: F401 + DatabaseError, # noqa: F401 + DataError, # noqa: F401 + OperationalError, # noqa: F401 + IntegrityError, # noqa: F401 + InternalError, # noqa: F401 + ProgrammingError, # noqa: F401 + NotSupportedError, # noqa: F401 ) import aiobotocore.session @@ -184,14 +196,14 @@ def __aiter__(self): async def __anext__(self): if self._paging_state: - if not hasattr(self, '_page_iterator'): + if not hasattr(self, "_page_iterator"): self._page_iterator = self._fetch_paginated_records() try: return await self._page_iterator.__anext__() except StopAsyncIteration: raise StopAsyncIteration else: - if not hasattr(self, '_record_index'): + if not hasattr(self, "_record_index"): self._record_index = 0 records = self._current_response.get("records", []) if self._record_index >= len(records): @@ -203,16 +215,16 @@ async def __anext__(self): async def _fetch_paginated_records(self): next_page_args = self._paging_state["execute_statement_args"] while True: - logger.debug( - "Fetching page of %d records for auto-paginated query", self._paging_state["records_per_page"] - ) + logger.debug("Fetching page of %d records for auto-paginated query", self._paging_state["records_per_page"]) next_page_args["sql"] = "FETCH {records_per_page} FROM {pg_cursor_name}".format(**self._paging_state) try: page = await self._client.execute_statement(**next_page_args) except self._client.exceptions.BadRequestException as e: cur_rpp = self._paging_state["records_per_page"] if "Database returned more than the allowed response size limit" in str(e) and cur_rpp > 1: - await self.scroll(-self._paging_state["records_per_page"]) # Rewind the cursor to read the page again + await self.scroll( + -self._paging_state["records_per_page"] + ) # Rewind the cursor to read the page again logger.debug("Halving records per page") self._paging_state["records_per_page"] //= 2 continue @@ -262,7 +274,7 @@ async def __aexit__(self, err_type, 
value, traceback): self._current_response = None -def connect( +async def connect( aurora_cluster_arn=None, secret_arn=None, rds_data_client=None, diff --git a/aurora_data_api/base.py b/aurora_data_api/base.py index 9d02649..279032f 100644 --- a/aurora_data_api/base.py +++ b/aurora_data_api/base.py @@ -1,26 +1,24 @@ """ Base classes for Aurora Data API clients and cursors """ -import os, datetime, ipaddress, uuid, time, random, string, logging, itertools, reprlib, json, re + +import os +import datetime +import ipaddress +import uuid +import time +import logging +import itertools +import re from decimal import Decimal from collections import namedtuple from collections.abc import Mapping from .exceptions import ( - Warning, - Error, - InterfaceError, DatabaseError, - DataError, - OperationalError, - IntegrityError, - InternalError, - ProgrammingError, NotSupportedError, MySQLError, PostgreSQLError, ) -from .error_codes_mysql import MySQLErrorCodes -from .error_codes_postgresql import PostgreSQLErrorCodes apilevel = "2.0" @@ -32,7 +30,12 @@ Time = datetime.time Timestamp = datetime.datetime DateFromTicks = datetime.date.fromtimestamp -# TimeFromTicks = datetime.time.fromtimestamp TODO + + +def TimeFromTicks(ticks): + return Time(*time.localtime(ticks)[3:6]) + + TimestampFromTicks = datetime.datetime.fromtimestamp Binary = bytes STRING = str @@ -50,7 +53,7 @@ class BaseAuroraDataAPIClient: """Base class for Aurora Data API clients""" - + def __init__( self, dbname=None, @@ -74,7 +77,7 @@ def close(self): class BaseAuroraDataAPICursor: """Base class for Aurora Data API cursors""" - + _pg_type_map = { "int": int, "int2": int, diff --git a/aurora_data_api/sync.py b/aurora_data_api/sync.py index 709f53c..56156a6 100644 --- a/aurora_data_api/sync.py +++ b/aurora_data_api/sync.py @@ -1,29 +1,42 @@ """ aurora-data-api - A Python DB-API 2.0 client for the AWS Aurora Serverless Data API """ -import os, datetime, ipaddress, uuid, time, random, string, logging, itertools, 
reprlib, json, re, threading -from decimal import Decimal -from collections import namedtuple -from collections.abc import Mapping + +import time +import random +import string +import reprlib +import threading from .base import BaseAuroraDataAPIClient, BaseAuroraDataAPICursor, logger from .base import ( - apilevel, threadsafety, paramstyle, Date, Time, Timestamp, DateFromTicks, - TimestampFromTicks, Binary, STRING, BINARY, NUMBER, DATETIME, ROWID, DECIMAL, - ColumnDescription + apilevel, # noqa: F401 + threadsafety, # noqa: F401 + paramstyle, # noqa: F401 + Date, # noqa: F401 + Time, # noqa: F401 + Timestamp, # noqa: F401 + DateFromTicks, # noqa: F401 + TimeFromTicks, # noqa: F401 + TimestampFromTicks, # noqa: F401 + Binary, # noqa: F401 + STRING, # noqa: F401 + BINARY, # noqa: F401 + NUMBER, # noqa: F401 + DATETIME, # noqa: F401 + ROWID, # noqa: F401 + DECIMAL, # noqa: F401 ) from .exceptions import ( - Warning, - Error, - InterfaceError, - DatabaseError, - DataError, - OperationalError, - IntegrityError, - InternalError, - ProgrammingError, - NotSupportedError, - MySQLError, - PostgreSQLError, + Warning, # noqa: F401 + Error, # noqa: F401 + InterfaceError, # noqa: F401 + DatabaseError, # noqa: F401 + DataError, # noqa: F401 + OperationalError, # noqa: F401 + IntegrityError, # noqa: F401 + InternalError, # noqa: F401 + ProgrammingError, # noqa: F401 + NotSupportedError, # noqa: F401 ) import botocore.session diff --git a/setup.py b/setup.py index 1981830..43aaeec 100755 --- a/setup.py +++ b/setup.py @@ -12,9 +12,7 @@ description="A Python DB-API 2.0 client for the AWS Aurora Serverless Data API", long_description=open("README.rst").read(), install_requires=["botocore >= 1.38.40, < 2"], - extras_require={ - "async": "aiobotocore >= 2.23.1, < 3" - }, + extras_require={"async": "aiobotocore >= 2.23.1, < 3"}, packages=find_packages(exclude=["test"]), platforms=["MacOS X", "Posix"], test_suite="test", diff --git a/test/base.py b/test/base.py new file mode 100644 index 
0000000..594ed2b --- /dev/null +++ b/test/base.py @@ -0,0 +1,252 @@ +""" +Base test classes for Aurora Data API async and sync tests +""" + +import datetime +import decimal +import json +import logging +import os +import sys +import unittest +from abc import ABC, abstractmethod + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +from aurora_data_api.error_codes_mysql import MySQLErrorCodes # noqa +from aurora_data_api.error_codes_postgresql import PostgreSQLErrorCodes # noqa + +logging.basicConfig(level=logging.INFO) +logging.getLogger("aurora_data_api").setLevel(logging.DEBUG) +logging.getLogger("urllib3.connectionpool").setLevel(logging.DEBUG) + + +# @unittest.skip +class CoreAuroraDataAPITest(unittest.TestCase): + """Base of base class to provide configuration for Aurora Data API tests""" + + @classmethod + def setUpClass(cls): + cls.db_name = os.environ.get("AURORA_DB_NAME", __name__) + cls.cluster_arn = os.environ.get("AURORA_CLUSTER_ARN") + cls.secret_arn = os.environ.get("SECRET_ARN") + + +# @unittest.skip +class BaseAuroraDataAPITest(CoreAuroraDataAPITest, ABC): + """Base test class for Aurora Data API tests (both sync and async)""" + + using_mysql = False + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls._setup_test_data() + + @classmethod + @abstractmethod + def _setup_test_data(cls): + """Abstract method to set up test data - implemented differently for sync/async""" + pass + + @classmethod + @abstractmethod + def _teardown_test_data(cls): + """Abstract method to tear down test data - implemented differently for sync/async""" + pass + + @classmethod + def tearDownClass(cls): + cls._teardown_test_data() + + @abstractmethod + def test_invalid_statements(self): + """Test invalid SQL statements""" + pass + + @abstractmethod + def test_iterators(self): + """Test cursor iteration functionality""" + pass + + @abstractmethod + def test_postgres_exceptions(self): + """Test PostgreSQL-specific exceptions""" + pass + + 
@abstractmethod + def test_rowcount(self): + """Test rowcount functionality""" + pass + + @abstractmethod + def test_continue_after_timeout(self): + """Test continue after timeout functionality""" + pass + + def get_expected_row0(self): + """Get the expected first row for testing""" + return ( + 1, + "row0", + # Note: data api v1 used JSON serialization with extra whitespace; v2 uses compact serialization + datetime.date(2000, 1, 1) if self.using_mysql else '{"x":0,"y":"0","z":[0,0,1]}', + decimal.Decimal(0.0), + datetime.datetime(2020, 9, 17, 13, 49, 33) + if self.using_mysql + else datetime.datetime(2020, 9, 17, 13, 49, 32, 780180), + ) + + def get_test_data(self): + """Get test data for insertion""" + if self.using_mysql: + return [ + { + "name": "row{}".format(i), + "birthday": "2000-01-01", + "num": decimal.Decimal("%d.%d" % (i, i)), + "ts": "2020-09-17 13:49:32.780180", + } + for i in range(2048) + ] + else: + return [ + { + "name": "row{}".format(i), + # Note: data api v1 supports up to 512**512 but v2 only supports up to 128**128 + "doc": json.dumps({"x": i, "y": str(i), "z": [i, i * i, i**i if i < 128 else 0]}), + "num": decimal.Decimal("%d.%d" % (i, i)), + "ts": "2020-09-17 13:49:32.780180", + } + for i in range(2048) + ] + + def get_postgresql_create_table_sql(self): + """Get PostgreSQL CREATE TABLE statement""" + return """ + CREATE TABLE aurora_data_api_test ( + id SERIAL, + name TEXT, + doc JSONB DEFAULT '{}', + num NUMERIC (10, 5) DEFAULT 0.0, + ts TIMESTAMP WITHOUT TIME ZONE + ) + """ + + def get_mysql_create_table_sql(self): + """Get MySQL CREATE TABLE statement""" + return ( + "CREATE TABLE aurora_data_api_test (id SERIAL, name TEXT, birthday DATE, num NUMERIC(10, 5), ts TIMESTAMP)" + ) + + def get_postgresql_insert_sql(self): + """Get PostgreSQL INSERT statement""" + return """ + INSERT INTO aurora_data_api_test(name, doc, num, ts) + VALUES (:name, CAST(:doc AS JSONB), :num, CAST(:ts AS TIMESTAMP)) + """ + + def get_mysql_insert_sql(self): + 
"""Get MySQL INSERT statement""" + return ( + "INSERT INTO aurora_data_api_test(name, birthday, num, ts) VALUES " + "(:name, :birthday, :num, CAST(:ts AS DATETIME))" + ) + + +class PEP249ConformanceTestMixin: + """ + A mixin containing a comprehensive set of formal conformance tests + for PEP 249 v2.0, based on the full specification. + Reference: https://peps.python.org/pep-0249/ + """ + + # Helper methods to perform assertions + def _assert_callable(self, obj, attr_name): + self.assertTrue(hasattr(obj, attr_name), f"'{attr_name}' must exist") + self.assertTrue(callable(getattr(obj, attr_name)), f"'{attr_name}' must be callable") + + def _assert_attribute(self, obj, attr_name): + self.assertTrue(hasattr(obj, attr_name), f"Attribute '{attr_name}' must exist") + + def _assert_optional_callable(self, obj, attr_name): + if hasattr(obj, attr_name): + self.assertTrue(callable(getattr(obj, attr_name)), f"If '{attr_name}' exists, it must be callable") + + # === [Module Interface] Tests === + + def test_module_globals_and_constructor(self): + """[Module] Tests for globals (apilevel, etc.) 
and the connect() constructor.""" + self._assert_callable(self.driver, "connect") + self._assert_attribute(self.driver, "apilevel") + self.assertEqual(self.driver.apilevel, "2.0", "apilevel must be '2.0' for this specification") + self._assert_attribute(self.driver, "threadsafety") + self.assertIn(self.driver.threadsafety, [0, 1, 2, 3]) + self._assert_attribute(self.driver, "paramstyle") + self.assertIn(self.driver.paramstyle, ["qmark", "numeric", "named", "format", "pyformat"]) + + def test_module_exceptions(self): + """[Module] Tests for the existence and hierarchy of all required exception classes.""" + exceptions = [ + "Warning", + "Error", + "InterfaceError", + "DatabaseError", + "DataError", + "OperationalError", + "IntegrityError", + "InternalError", + "ProgrammingError", + "NotSupportedError", + ] + for exc_name in exceptions: + self._assert_attribute(self.driver, exc_name) + self.assertTrue(issubclass(self.driver.Warning, Exception)) + self.assertTrue(issubclass(self.driver.Error, Exception)) + self.assertTrue(issubclass(self.driver.InterfaceError, self.driver.Error)) + self.assertTrue(issubclass(self.driver.DatabaseError, self.driver.Error)) + self.assertTrue(issubclass(self.driver.DataError, self.driver.DatabaseError)) + self.assertTrue(issubclass(self.driver.OperationalError, self.driver.DatabaseError)) + self.assertTrue(issubclass(self.driver.IntegrityError, self.driver.DatabaseError)) + self.assertTrue(issubclass(self.driver.InternalError, self.driver.DatabaseError)) + self.assertTrue(issubclass(self.driver.ProgrammingError, self.driver.DatabaseError)) + self.assertTrue(issubclass(self.driver.NotSupportedError, self.driver.DatabaseError)) + + def test_module_type_objects_and_constructors(self): + """[Module] Tests for the existence of all required Type Objects and Constructors.""" + for constructor in [ + "Date", + "Time", + "Timestamp", + "DateFromTicks", + "TimeFromTicks", + "TimestampFromTicks", + "Binary", + ]: + self._assert_callable(self.driver, 
constructor) + for type_obj in ["STRING", "BINARY", "NUMBER", "DATETIME", "ROWID"]: + self._assert_attribute(self.driver, type_obj) + + # === [Connection and Cursor] Tests and Assertion Helpers === + def test_connection_interface(self): + """[Connection] Tests to assert the complete interface of a Connection object.""" + self._assert_callable(self.connection, "close") + self._assert_callable(self.connection, "commit") + self._assert_callable(self.connection, "cursor") + self._assert_optional_callable(self.connection, "rollback") + + def _assert_cursor_interface(self, cursor): + """[Cursor] Helper to assert the complete interface of a Cursor object.""" + self._assert_attribute(cursor, "description") + self._assert_attribute(cursor, "rowcount") + self._assert_attribute(cursor, "arraysize") + self._assert_callable(cursor, "close") + self._assert_callable(cursor, "execute") + self._assert_callable(cursor, "executemany") + self._assert_callable(cursor, "fetchone") + self._assert_callable(cursor, "fetchall") + self._assert_callable(cursor, "fetchmany") + self._assert_callable(cursor, "setinputsizes") + self._assert_callable(cursor, "setoutputsize") + self._assert_optional_callable(cursor, "callproc") + self._assert_optional_callable(cursor, "nextset") diff --git a/test/test_async.py b/test/test_async.py index a244764..d0339da 100644 --- a/test/test_async.py +++ b/test/test_async.py @@ -1,5 +1,4 @@ import asyncio -import datetime import decimal import json import logging @@ -9,9 +8,11 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) -import aurora_data_api.async_ as async_ # noqa +import aurora_data_api.async_ as async_ # noqa +import aurora_data_api.exceptions as exceptions # noqa from aurora_data_api.error_codes_mysql import MySQLErrorCodes # noqa from aurora_data_api.error_codes_postgresql import PostgreSQLErrorCodes # noqa +from base import BaseAuroraDataAPITest, CoreAuroraDataAPITest, PEP249ConformanceTestMixin # noqa 
logging.basicConfig(level=logging.INFO) logging.getLogger("aurora_data_api").setLevel(logging.DEBUG) @@ -20,156 +21,138 @@ class AsyncTestCase(unittest.TestCase): """Base class for async test cases.""" - + def setUp(self): self.loop = asyncio.new_event_loop() asyncio.set_event_loop(self.loop) - + def tearDown(self): self.loop.close() - + def async_test(coro): """Decorator to run async test methods.""" + def wrapper(self): return self.loop.run_until_complete(coro(self)) + return wrapper -class TestAuroraDataAPI(AsyncTestCase): - using_mysql = False +class TestAuroraDataAPI(BaseAuroraDataAPITest, AsyncTestCase): + """Asynchronous Aurora Data API tests""" @classmethod - def setUpClass(cls): + def _setup_test_data(cls): + """Set up test data for async tests""" loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) try: - loop.run_until_complete(cls._async_setUpClass()) + loop.run_until_complete(cls._async_setup_test_data()) finally: loop.close() - + @classmethod - async def _async_setUpClass(cls): - cls.db_name = os.environ.get("AURORA_DB_NAME", __name__) - cls.cluster_arn = os.environ.get("AURORA_CLUSTER_ARN") - cls.secret_arn = os.environ.get("SECRET_ARN") - async with await async_.connect(database=cls.db_name, aurora_cluster_arn=cls.cluster_arn, secret_arn=cls.secret_arn) as conn: + async def _async_setup_test_data(cls): + """Async helper for setting up test data""" + async with async_.connect( + database=cls.db_name, aurora_cluster_arn=cls.cluster_arn, secret_arn=cls.secret_arn + ) as conn: cur = await conn.cursor() try: await cur.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp"') await cur.execute("DROP TABLE IF EXISTS aurora_data_api_test") - await cur.execute( - """ - CREATE TABLE aurora_data_api_test ( - id SERIAL, - name TEXT, - doc JSONB DEFAULT '{}', - num NUMERIC (10, 5) DEFAULT 0.0, - ts TIMESTAMP WITHOUT TIME ZONE - ) - """ - ) + await cur.execute(cls().get_postgresql_create_table_sql()) await cur.executemany( - """ - INSERT INTO 
aurora_data_api_test(name, doc, num, ts) - VALUES (:name, CAST(:doc AS JSONB), :num, CAST(:ts AS TIMESTAMP)) - """, - [ - { - "name": "row{}".format(i), - # Note: data api v1 supports up to 512**512 but v2 only supports up to 128**128 - "doc": json.dumps({"x": i, "y": str(i), "z": [i, i * i, i**i if i < 128 else 0]}), - "num": decimal.Decimal("%d.%d" % (i, i)), - "ts": "2020-09-17 13:49:32.780180", - } - for i in range(2048) - ], + cls().get_postgresql_insert_sql(), + cls().get_test_data(), ) - except async_.MySQLError.ER_PARSE_ERROR: + except exceptions.MySQLError.ER_PARSE_ERROR: cls.using_mysql = True await cur.execute("DROP TABLE IF EXISTS aurora_data_api_test") - await cur.execute( - "CREATE TABLE aurora_data_api_test " - "(id SERIAL, name TEXT, birthday DATE, num NUMERIC(10, 5), ts TIMESTAMP)" - ) + await cur.execute(cls().get_mysql_create_table_sql()) await cur.executemany( - ( - "INSERT INTO aurora_data_api_test(name, birthday, num, ts) VALUES " - "(:name, :birthday, :num, CAST(:ts AS DATETIME))" - ), - [ - { - "name": "row{}".format(i), - "birthday": "2000-01-01", - "num": decimal.Decimal("%d.%d" % (i, i)), - "ts": "2020-09-17 13:49:32.780180", - } - for i in range(2048) - ], + cls().get_mysql_insert_sql(), + cls().get_test_data(), ) @classmethod - def tearDownClass(cls): + def _teardown_test_data(cls): + """Tear down test data for async tests""" loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) try: - loop.run_until_complete(cls._async_tearDownClass()) + loop.run_until_complete(cls._async_teardown_test_data()) finally: loop.close() - + @classmethod - async def _async_tearDownClass(cls): - async with await async_.connect(database=cls.db_name, aurora_cluster_arn=cls.cluster_arn, secret_arn=cls.secret_arn) as conn: + async def _async_teardown_test_data(cls): + """Async helper for tearing down test data""" + async with async_.connect( + database=cls.db_name, aurora_cluster_arn=cls.cluster_arn, secret_arn=cls.secret_arn + ) as conn: cur = await 
conn.cursor() await cur.execute("DROP TABLE IF EXISTS aurora_data_api_test") @AsyncTestCase.async_test async def test_invalid_statements(self): - async with await async_.connect(database=self.db_name, aurora_cluster_arn=self.cluster_arn, secret_arn=self.secret_arn) as conn: + async with async_.connect( + database=self.db_name, aurora_cluster_arn=self.cluster_arn, secret_arn=self.secret_arn + ) as conn: cur = await conn.cursor() - with self.assertRaises( - (async_.exceptions.PostgreSQLError.ER_SYNTAX_ERR, async_.MySQLError.ER_PARSE_ERROR) - ): + with self.assertRaises((exceptions.PostgreSQLError.ER_SYNTAX_ERR, exceptions.MySQLError.ER_PARSE_ERROR)): await cur.execute("selec * from table") @AsyncTestCase.async_test async def test_iterators(self): - async with await async_.connect(database=self.db_name, aurora_cluster_arn=self.cluster_arn, secret_arn=self.secret_arn) as conn: + """Test cursor iteration functionality""" + async with async_.connect( + database=self.db_name, aurora_cluster_arn=self.cluster_arn, secret_arn=self.secret_arn + ) as conn: cur = await conn.cursor() if not self.using_mysql: - await cur.execute("select count(*) from aurora_data_api_test where pg_column_size(doc) < :s", dict(s=2**6)) + cur = await conn.cursor() + await cur.execute( + "select count(*) from aurora_data_api_test where pg_column_size(doc) < :s", dict(s=2**6) + ) result = await cur.fetchone() self.assertEqual(result[0], 0) - await cur.execute("select count(*) from aurora_data_api_test where pg_column_size(doc) < :s", dict(s=2**7)) + + cur = await conn.cursor() + await cur.execute( + "select count(*) from aurora_data_api_test where pg_column_size(doc) < :s", dict(s=2**7) + ) result = await cur.fetchone() self.assertEqual(result[0], 1977) - await cur.execute("select count(*) from aurora_data_api_test where pg_column_size(doc) < :s", dict(s=2**8)) + + cur = await conn.cursor() + await cur.execute( + "select count(*) from aurora_data_api_test where pg_column_size(doc) < :s", dict(s=2**8) 
+ ) result = await cur.fetchone() self.assertEqual(result[0], 2048) - await cur.execute("select count(*) from aurora_data_api_test where pg_column_size(doc) < :s", dict(s=2**10)) + + cur = await conn.cursor() + await cur.execute( + "select count(*) from aurora_data_api_test where pg_column_size(doc) < :s", dict(s=2**10) + ) result = await cur.fetchone() self.assertEqual(result[0], 2048) - cursor = await conn.cursor() - expect_row0 = ( - 1, - "row0", - # Note: data api v1 used JSON serialization with extra whitespace; v2 uses compact serialization - datetime.date(2000, 1, 1) if self.using_mysql else '{"x":0,"y":"0","z":[0,0,1]}', - decimal.Decimal(0.0), - datetime.datetime(2020, 9, 17, 13, 49, 33) - if self.using_mysql - else datetime.datetime(2020, 9, 17, 13, 49, 32, 780180), - ) + cur = await conn.cursor() + # cursor = await conn.cursor() + expect_row0 = self.get_expected_row0() i = 0 - await cursor.execute("select * from aurora_data_api_test") - async for f in cursor: + await cur.execute("select * from aurora_data_api_test") + async for f in cur: if i == 0: self.assertEqual(f, expect_row0) i += 1 self.assertEqual(i, 2048) - await cursor.execute("select * from aurora_data_api_test") - data = await cursor.fetchall() + cur = await conn.cursor() + await cur.execute("select * from aurora_data_api_test") + data = await cur.fetchall() self.assertEqual(data[0], expect_row0) self.assertEqual(data[-1][0], 2048) self.assertEqual(data[-1][1], "row2047") @@ -177,20 +160,22 @@ async def test_iterators(self): self.assertEqual(json.loads(data[-1][2]), {"x": 2047, "y": str(2047), "z": [2047, 2047 * 2047, 0]}) self.assertEqual(data[-1][-2], decimal.Decimal("2047.2047")) self.assertEqual(len(data), 2048) - self.assertEqual(len(await cursor.fetchall()), 0) + self.assertEqual(len(await cur.fetchall()), 0) - await cursor.execute("select * from aurora_data_api_test") + cur = await conn.cursor() + await cur.execute("select * from aurora_data_api_test") i = 0 while True: - result = await 
cursor.fetchone() + result = await cur.fetchone() if not result: break i += 1 self.assertEqual(i, 2048) - await cursor.execute("select * from aurora_data_api_test") + cur = await conn.cursor() + await cur.execute("select * from aurora_data_api_test") while True: - fm = await cursor.fetchmany(1001) + fm = await cur.fetchmany(1001) if not fm: break self.assertIn(len(fm), [1001, 46]) @@ -203,7 +188,9 @@ async def test_iterators(self): async def test_pagination_backoff(self): if self.using_mysql: return - async with await async_.connect(database=self.db_name, aurora_cluster_arn=self.cluster_arn, secret_arn=self.secret_arn) as conn: + async with async_.connect( + database=self.db_name, aurora_cluster_arn=self.cluster_arn, secret_arn=self.secret_arn + ) as conn: cur = await conn.cursor() sql_template = "select concat({}) from aurora_data_api_test" sql = sql_template.format(", ".join(["cast(doc as text)"] * 64)) @@ -214,19 +201,20 @@ async def test_pagination_backoff(self): concat_args = ", ".join(["cast(doc as text)"] * 100) sql = sql_template.format(", ".join("concat({})".format(concat_args) for i in range(32))) await cur.execute(sql) - with self.assertRaisesRegex( - Exception, "Database response exceeded size limit" - ): + with self.assertRaisesRegex(Exception, "Database response exceeded size limit"): await cur.fetchall() @AsyncTestCase.async_test async def test_postgres_exceptions(self): + """Test PostgreSQL-specific exceptions""" if self.using_mysql: return - async with await async_.connect(database=self.db_name, aurora_cluster_arn=self.cluster_arn, secret_arn=self.secret_arn) as conn: + async with async_.connect( + database=self.db_name, aurora_cluster_arn=self.cluster_arn, secret_arn=self.secret_arn + ) as conn: cur = await conn.cursor() table = "aurora_data_api_nonexistent_test_table" - with self.assertRaises(async_.exceptions.PostgreSQLError.ER_UNDEF_TABLE) as e: + with self.assertRaises(exceptions.PostgreSQLError.ER_UNDEF_TABLE) as e: sql = f"select * from 
{table}" await cur.execute(sql) self.assertTrue(f'relation "{table}" does not exist' in str(e.exception)) @@ -234,12 +222,17 @@ async def test_postgres_exceptions(self): @AsyncTestCase.async_test async def test_rowcount(self): - async with await async_.connect(database=self.db_name, aurora_cluster_arn=self.cluster_arn, secret_arn=self.secret_arn) as conn: + """Test rowcount functionality""" + async with async_.connect( + database=self.db_name, aurora_cluster_arn=self.cluster_arn, secret_arn=self.secret_arn + ) as conn: cur = await conn.cursor() await cur.execute("select * from aurora_data_api_test limit 8") self.assertEqual(cur.rowcount, 8) - async with await async_.connect(database=self.db_name, aurora_cluster_arn=self.cluster_arn, secret_arn=self.secret_arn) as conn: + async with async_.connect( + database=self.db_name, aurora_cluster_arn=self.cluster_arn, secret_arn=self.secret_arn + ) as conn: cur = await conn.cursor() await cur.execute("select * from aurora_data_api_test limit 9000") self.assertEqual(cur.rowcount, 2048) @@ -247,7 +240,9 @@ async def test_rowcount(self): if self.using_mysql: return - async with await async_.connect(database=self.db_name, aurora_cluster_arn=self.cluster_arn, secret_arn=self.secret_arn) as conn: + async with async_.connect( + database=self.db_name, aurora_cluster_arn=self.cluster_arn, secret_arn=self.secret_arn + ) as conn: cur = await conn.cursor() await cur.executemany( "INSERT INTO aurora_data_api_test(name, doc) VALUES (:name, CAST(:doc AS JSONB))", @@ -268,6 +263,7 @@ async def test_rowcount(self): @AsyncTestCase.async_test async def test_continue_after_timeout(self): + """Test continue after timeout functionality""" if os.environ.get("TEST_CONTINUE_AFTER_TIMEOUT", "False") != "True": self.skipTest("TEST_CONTINUE_AFTER_TIMEOUT env var is not 'True'") @@ -275,7 +271,9 @@ async def test_continue_after_timeout(self): self.skipTest("Not implemented for MySQL") try: - async with await async_.connect(database=self.db_name, 
aurora_cluster_arn=self.cluster_arn, secret_arn=self.secret_arn) as conn: + async with async_.connect( + database=self.db_name, aurora_cluster_arn=self.cluster_arn, secret_arn=self.secret_arn + ) as conn: cur = await conn.cursor() with self.assertRaisesRegex(Exception, "StatementTimeoutException"): await cur.execute( @@ -287,15 +285,19 @@ async def test_continue_after_timeout(self): with self.assertRaisesRegex(async_.DatabaseError, "current transaction is aborted"): await cur.execute("SELECT COUNT(*) FROM aurora_data_api_test WHERE name = 'continue_after_timeout'") - async with await async_.connect(database=self.db_name, aurora_cluster_arn=self.cluster_arn, secret_arn=self.secret_arn) as conn: + async with async_.connect( + database=self.db_name, aurora_cluster_arn=self.cluster_arn, secret_arn=self.secret_arn + ) as conn: cur = await conn.cursor() await cur.execute("SELECT COUNT(*) FROM aurora_data_api_test WHERE name = 'continue_after_timeout'") result = await cur.fetchone() self.assertEqual(result, (0,)) - async with await async_.connect( - database=self.db_name, continue_after_timeout=True, - aurora_cluster_arn=self.cluster_arn, secret_arn=self.secret_arn + async with async_.connect( + database=self.db_name, + continue_after_timeout=True, + aurora_cluster_arn=self.cluster_arn, + secret_arn=self.secret_arn, ) as conn: cur = await conn.cursor() with self.assertRaisesRegex(Exception, "StatementTimeoutException"): @@ -309,10 +311,35 @@ async def test_continue_after_timeout(self): result = await cur.fetchone() self.assertEqual(result, (1,)) finally: - async with await async_.connect(database=self.db_name, aurora_cluster_arn=self.cluster_arn, secret_arn=self.secret_arn) as conn: + async with async_.connect( + database=self.db_name, aurora_cluster_arn=self.cluster_arn, secret_arn=self.secret_arn + ) as conn: cur = await conn.cursor() await cur.execute("DELETE FROM aurora_data_api_test WHERE name = 'continue_after_timeout'") +class 
TestAuroraDataAPIConformance(PEP249ConformanceTestMixin, CoreAuroraDataAPITest): + """Conformance test class for asynchronous tests. Sets up connection to run tests.""" + + driver = async_ + connection = None + + def setUp(self): + """Setup mock objects for connection and cursor as member variables.""" + self.connection = self.driver.connect( + database=self.db_name, aurora_cluster_arn=self.cluster_arn, secret_arn=self.secret_arn + ) + + async def test_cursor_interface(self): + """[Cursor] Runs formal interface checks on the cursor.""" + cur = await self.connection.cursor() + self._assert_cursor_interface(cur) + + +# Remove these classes to avoid instantiation errors +del AsyncTestCase +del CoreAuroraDataAPITest +del BaseAuroraDataAPITest + if __name__ == "__main__": unittest.main() diff --git a/test/test_compat.py b/test/test_compat.py new file mode 100644 index 0000000..bf6f8c5 --- /dev/null +++ b/test/test_compat.py @@ -0,0 +1,38 @@ +import logging +import os +import sys +import unittest +import aurora_data_api +from base import CoreAuroraDataAPITest, PEP249ConformanceTestMixin + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +logging.basicConfig(level=logging.INFO) +logging.getLogger("aurora_data_api").setLevel(logging.DEBUG) +logging.getLogger("urllib3.connectionpool").setLevel(logging.DEBUG) + + +class TestAuroraDataAPIConformance(PEP249ConformanceTestMixin, CoreAuroraDataAPITest): + """Conformance test class for synchronous tests. 
Sets up connection to run tests.""" + + driver = aurora_data_api + connection = None + + def setUp(self): + """Setup mock objects for connection and cursor as member variables.""" + self.connection = self.driver.connect( + database=self.db_name, aurora_cluster_arn=self.cluster_arn, secret_arn=self.secret_arn + ) + + def test_cursor_interface(self): + """[Cursor] Runs formal interface checks on the cursor.""" + with self.connection.cursor() as cur: + self._assert_cursor_interface(cur) + + +# Remove these classes to avoid instantiation errors +del CoreAuroraDataAPITest + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_sync.py b/test/test_sync.py index e4a6715..1d4d077 100644 --- a/test/test_sync.py +++ b/test/test_sync.py @@ -1,4 +1,3 @@ -import datetime import decimal import json import logging @@ -9,90 +8,70 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) import aurora_data_api.sync as sync # noqa +import aurora_data_api.exceptions as exceptions # noqa from aurora_data_api.error_codes_mysql import MySQLErrorCodes # noqa from aurora_data_api.error_codes_postgresql import PostgreSQLErrorCodes # noqa +from base import BaseAuroraDataAPITest, CoreAuroraDataAPITest, PEP249ConformanceTestMixin # noqa logging.basicConfig(level=logging.INFO) logging.getLogger("aurora_data_api").setLevel(logging.DEBUG) logging.getLogger("urllib3.connectionpool").setLevel(logging.DEBUG) -class TestAuroraDataAPI(unittest.TestCase): - using_mysql = False +class TestAuroraDataAPI(BaseAuroraDataAPITest): + """Synchronous Aurora Data API tests""" @classmethod - def setUpClass(cls): - cls.db_name = os.environ.get("AURORA_DB_NAME", __name__) - cls.cluster_arn = os.environ.get("AURORA_CLUSTER_ARN") - cls.secret_arn = os.environ.get("SECRET_ARN") - with sync.connect(database=cls.db_name, aurora_cluster_arn=cls.cluster_arn, secret_arn=cls.secret_arn) as conn, conn.cursor() as cur: + def _setup_test_data(cls): + """Set up test data for sync tests""" + with ( + 
sync.connect(database=cls.db_name, aurora_cluster_arn=cls.cluster_arn, secret_arn=cls.secret_arn) as conn, + conn.cursor() as cur, + ): try: cur.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp"') cur.execute("DROP TABLE IF EXISTS aurora_data_api_test") - cur.execute( - """ - CREATE TABLE aurora_data_api_test ( - id SERIAL, - name TEXT, - doc JSONB DEFAULT '{}', - num NUMERIC (10, 5) DEFAULT 0.0, - ts TIMESTAMP WITHOUT TIME ZONE - ) - """ - ) + cur.execute(cls().get_postgresql_create_table_sql()) cur.executemany( - """ - INSERT INTO aurora_data_api_test(name, doc, num, ts) - VALUES (:name, CAST(:doc AS JSONB), :num, CAST(:ts AS TIMESTAMP)) - """, - [ - { - "name": "row{}".format(i), - # Note: data api v1 supports up to 512**512 but v2 only supports up to 128**128 - "doc": json.dumps({"x": i, "y": str(i), "z": [i, i * i, i**i if i < 128 else 0]}), - "num": decimal.Decimal("%d.%d" % (i, i)), - "ts": "2020-09-17 13:49:32.780180", - } - for i in range(2048) - ], + cls().get_postgresql_insert_sql(), + cls().get_test_data(), ) except sync.MySQLError.ER_PARSE_ERROR: cls.using_mysql = True cur.execute("DROP TABLE IF EXISTS aurora_data_api_test") - cur.execute( - "CREATE TABLE aurora_data_api_test " - "(id SERIAL, name TEXT, birthday DATE, num NUMERIC(10, 5), ts TIMESTAMP)" - ) + cur.execute(cls().get_mysql_create_table_sql()) cur.executemany( - ( - "INSERT INTO aurora_data_api_test(name, birthday, num, ts) VALUES " - "(:name, :birthday, :num, CAST(:ts AS DATETIME))" - ), - [ - { - "name": "row{}".format(i), - "birthday": "2000-01-01", - "num": decimal.Decimal("%d.%d" % (i, i)), - "ts": "2020-09-17 13:49:32.780180", - } - for i in range(2048) - ], + cls().get_mysql_insert_sql(), + cls().get_test_data(), ) @classmethod - def tearDownClass(cls): - with sync.connect(database=cls.db_name) as conn, conn.cursor() as cur: + def _teardown_test_data(cls): + """Tear down test data for sync tests""" + with ( + sync.connect(database=cls.db_name, aurora_cluster_arn=cls.cluster_arn, 
secret_arn=cls.secret_arn) as conn, + conn.cursor() as cur, + ): cur.execute("DROP TABLE IF EXISTS aurora_data_api_test") def test_invalid_statements(self): - with sync.connect(database=self.db_name) as conn, conn.cursor() as cur: - with self.assertRaises( - (sync.exceptions.PostgreSQLError.ER_SYNTAX_ERR, sync.MySQLError.ER_PARSE_ERROR) - ): + with ( + sync.connect( + database=self.db_name, aurora_cluster_arn=self.cluster_arn, secret_arn=self.secret_arn + ) as conn, + conn.cursor() as cur, + ): + with self.assertRaises((exceptions.PostgreSQLError.ER_SYNTAX_ERR, exceptions.MySQLError.ER_PARSE_ERROR)): cur.execute("selec * from table") def test_iterators(self): - with sync.connect(database=self.db_name) as conn, conn.cursor() as cur: + """Test cursor iteration functionality""" + with ( + sync.connect( + database=self.db_name, aurora_cluster_arn=self.cluster_arn, secret_arn=self.secret_arn + ) as conn, + conn.cursor() as cur, + ): if not self.using_mysql: cur.execute("select count(*) from aurora_data_api_test where pg_column_size(doc) < :s", dict(s=2**6)) self.assertEqual(cur.fetchone()[0], 0) @@ -104,16 +83,7 @@ def test_iterators(self): self.assertEqual(cur.fetchone()[0], 2048) with conn.cursor() as cursor: - expect_row0 = ( - 1, - "row0", - # Note: data api v1 used JSON serialization with extra whitespace; v2 uses compact serialization - datetime.date(2000, 1, 1) if self.using_mysql else '{"x":0,"y":"0","z":[0,0,1]}', - decimal.Decimal(0.0), - datetime.datetime(2020, 9, 17, 13, 49, 33) - if self.using_mysql - else datetime.datetime(2020, 9, 17, 13, 49, 32, 780180), - ) + expect_row0 = self.get_expected_row0() i = 0 cursor.execute("select * from aurora_data_api_test") for f in cursor: @@ -155,7 +125,12 @@ def test_iterators(self): def test_pagination_backoff(self): if self.using_mysql: return - with sync.connect(database=self.db_name) as conn, conn.cursor() as cur: + with ( + sync.connect( + database=self.db_name, aurora_cluster_arn=self.cluster_arn, 
secret_arn=self.secret_arn + ) as conn, + conn.cursor() as cur, + ): sql_template = "select concat({}) from aurora_data_api_test" sql = sql_template.format(", ".join(["cast(doc as text)"] * 64)) cur.execute(sql) @@ -170,29 +145,51 @@ def test_pagination_backoff(self): cur.fetchall() def test_postgres_exceptions(self): + """Test PostgreSQL-specific exceptions""" if self.using_mysql: return - with sync.connect(database=self.db_name) as conn, conn.cursor() as cur: + with ( + sync.connect( + database=self.db_name, aurora_cluster_arn=self.cluster_arn, secret_arn=self.secret_arn + ) as conn, + conn.cursor() as cur, + ): table = "aurora_data_api_nonexistent_test_table" - with self.assertRaises(sync.exceptions.PostgreSQLError.ER_UNDEF_TABLE) as e: + with self.assertRaises(exceptions.PostgreSQLError.ER_UNDEF_TABLE) as e: sql = f"select * from {table}" cur.execute(sql) self.assertTrue(f'relation "{table}" does not exist' in str(e.exception)) self.assertTrue(isinstance(e.exception.response, dict)) def test_rowcount(self): - with sync.connect(database=self.db_name) as conn, conn.cursor() as cur: + """Test rowcount functionality""" + with ( + sync.connect( + database=self.db_name, aurora_cluster_arn=self.cluster_arn, secret_arn=self.secret_arn + ) as conn, + conn.cursor() as cur, + ): cur.execute("select * from aurora_data_api_test limit 8") self.assertEqual(cur.rowcount, 8) - with sync.connect(database=self.db_name) as conn, conn.cursor() as cur: + with ( + sync.connect( + database=self.db_name, aurora_cluster_arn=self.cluster_arn, secret_arn=self.secret_arn + ) as conn, + conn.cursor() as cur, + ): cur.execute("select * from aurora_data_api_test limit 9000") self.assertEqual(cur.rowcount, 2048) if self.using_mysql: return - with sync.connect(database=self.db_name) as conn, conn.cursor() as cur: + with ( + sync.connect( + database=self.db_name, aurora_cluster_arn=self.cluster_arn, secret_arn=self.secret_arn + ) as conn, + conn.cursor() as cur, + ): cur.executemany( "INSERT 
INTO aurora_data_api_test(name, doc) VALUES (:name, CAST(:doc AS JSONB))", [ @@ -211,6 +208,7 @@ def test_rowcount(self): self.assertEqual(cur.rowcount, 8) def test_continue_after_timeout(self): + """Test continue after timeout functionality""" if os.environ.get("TEST_CONTINUE_AFTER_TIMEOUT", "False") != "True": self.skipTest("TEST_CONTINUE_AFTER_TIMEOUT env var is not 'True'") @@ -218,7 +216,12 @@ def test_continue_after_timeout(self): self.skipTest("Not implemented for MySQL") try: - with sync.connect(database=self.db_name) as conn, conn.cursor() as cur: + with ( + sync.connect( + database=self.db_name, aurora_cluster_arn=self.cluster_arn, secret_arn=self.secret_arn + ) as conn, + conn.cursor() as cur, + ): with self.assertRaisesRegex(conn._client.exceptions.ClientError, "StatementTimeoutException"): cur.execute( ( @@ -226,16 +229,19 @@ def test_continue_after_timeout(self): "FROM (SELECT pg_sleep(50)) q" ) ) - with self.assertRaisesRegex(sync.DatabaseError, "current transaction is aborted"): + with self.assertRaisesRegex(exceptions.DatabaseError, "current transaction is aborted"): cur.execute("SELECT COUNT(*) FROM aurora_data_api_test WHERE name = 'continue_after_timeout'") - with sync.connect(database=self.db_name) as conn, conn.cursor() as cur: + with ( + sync.connect( + database=self.db_name, aurora_cluster_arn=self.cluster_arn, secret_arn=self.secret_arn + ) as conn, + conn.cursor() as cur, + ): cur.execute("SELECT COUNT(*) FROM aurora_data_api_test WHERE name = 'continue_after_timeout'") self.assertEqual(cur.fetchone(), (0,)) - with sync.connect( - database=self.db_name, continue_after_timeout=True - ) as conn, conn.cursor() as cur: + with sync.connect(database=self.db_name, continue_after_timeout=True, aurora_cluster_arn=self.cluster_arn, secret_arn=self.secret_arn) as conn, conn.cursor() as cur: with self.assertRaisesRegex(conn._client.exceptions.ClientError, "StatementTimeoutException"): cur.execute( ( @@ -246,9 +252,37 @@ def test_continue_after_timeout(self): cur.execute("SELECT COUNT(*) FROM aurora_data_api_test 
WHERE name = 'continue_after_timeout'") self.assertEqual(cur.fetchone(), (1,)) finally: - with sync.connect(database=self.db_name) as conn, conn.cursor() as cur: + with ( + sync.connect( + database=self.db_name, aurora_cluster_arn=self.cluster_arn, secret_arn=self.secret_arn + ) as conn, + conn.cursor() as cur, + ): cur.execute("DELETE FROM aurora_data_api_test WHERE name = 'continue_after_timeout'") +class TestAuroraDataAPIConformance(PEP249ConformanceTestMixin, CoreAuroraDataAPITest): + """Conformance test class for synchronous tests. Sets up connection to run tests.""" + + driver = sync + connection = None + + def setUp(self): + """Setup mock objects for connection and cursor as member variables.""" + self.connection = self.driver.connect( + database=self.db_name, aurora_cluster_arn=self.cluster_arn, secret_arn=self.secret_arn + ) + + def test_cursor_interface(self): + """[Cursor] Runs formal interface checks on the cursor.""" + with self.connection.cursor() as cur: + self._assert_cursor_interface(cur) + + +# Remove these classes to avoid instantiation errors +del CoreAuroraDataAPITest +del BaseAuroraDataAPITest + + if __name__ == "__main__": unittest.main()