From 329b75df5241fac68f49b8f3a67df7c386c90d32 Mon Sep 17 00:00:00 2001 From: tsuga <2888173+tsuga@users.noreply.github.com> Date: Sun, 20 Jul 2025 11:19:40 +0000 Subject: [PATCH 1/9] Rename error code files --- aurora_data_api/__init__.py | 4 ++-- .../{mysql_error_codes.py => error_codes_mysql.py} | 0 .../{postgresql_error_codes.py => error_codes_postgresql.py} | 0 aurora_data_api/exceptions.py | 4 ++-- test/test.py | 4 ++-- 5 files changed, 6 insertions(+), 6 deletions(-) rename aurora_data_api/{mysql_error_codes.py => error_codes_mysql.py} (100%) rename aurora_data_api/{postgresql_error_codes.py => error_codes_postgresql.py} (100%) diff --git a/aurora_data_api/__init__.py b/aurora_data_api/__init__.py index fbcf7bd..19538e3 100644 --- a/aurora_data_api/__init__.py +++ b/aurora_data_api/__init__.py @@ -19,8 +19,8 @@ MySQLError, PostgreSQLError, ) -from .mysql_error_codes import MySQLErrorCodes -from .postgresql_error_codes import PostgreSQLErrorCodes +from .error_codes_mysql import MySQLErrorCodes +from .error_codes_postgresql import PostgreSQLErrorCodes import boto3 apilevel = "2.0" diff --git a/aurora_data_api/mysql_error_codes.py b/aurora_data_api/error_codes_mysql.py similarity index 100% rename from aurora_data_api/mysql_error_codes.py rename to aurora_data_api/error_codes_mysql.py diff --git a/aurora_data_api/postgresql_error_codes.py b/aurora_data_api/error_codes_postgresql.py similarity index 100% rename from aurora_data_api/postgresql_error_codes.py rename to aurora_data_api/error_codes_postgresql.py diff --git a/aurora_data_api/exceptions.py b/aurora_data_api/exceptions.py index 9c01a2e..04ef1c8 100644 --- a/aurora_data_api/exceptions.py +++ b/aurora_data_api/exceptions.py @@ -1,5 +1,5 @@ -from .mysql_error_codes import MySQLErrorCodes -from .postgresql_error_codes import PostgreSQLErrorCodes +from .error_codes_mysql import MySQLErrorCodes +from .error_codes_postgresql import PostgreSQLErrorCodes class Warning(Exception): diff --git a/test/test.py b/test/test.py index 9b75bb0..7ce21ab 100644 --- a/test/test.py +++ b/test/test.py @@ -9,8 +9,8 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) import aurora_data_api # noqa -from aurora_data_api.mysql_error_codes import MySQLErrorCodes # noqa -from aurora_data_api.postgresql_error_codes import PostgreSQLErrorCodes # noqa +from aurora_data_api.error_codes_mysql import MySQLErrorCodes # noqa +from aurora_data_api.error_codes_postgresql import PostgreSQLErrorCodes # noqa logging.basicConfig(level=logging.INFO) logging.getLogger("aurora_data_api").setLevel(logging.DEBUG) From 28cc4bb874df3ec0b1020ef834934aa402173e8f Mon Sep 17 00:00:00 2001 From: tsuga <2888173+tsuga@users.noreply.github.com> Date: Sun, 20 Jul 2025 11:54:58 +0000 Subject: [PATCH 2/9] Change boto3 to botocore --- aurora_data_api/__init__.py | 5 +++-- setup.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/aurora_data_api/__init__.py b/aurora_data_api/__init__.py index 19538e3..43a23dc 100644 --- a/aurora_data_api/__init__.py +++ b/aurora_data_api/__init__.py @@ -21,7 +21,7 @@ ) from .error_codes_mysql import MySQLErrorCodes from .error_codes_postgresql import PostgreSQLErrorCodes -import boto3 +import botocore.session apilevel = "2.0" @@ -64,7 +64,8 @@ def __init__( self._client = rds_data_client if rds_data_client is None: with self._client_init_lock: - self._client = boto3.client("rds-data") + session = botocore.session.get_session() + self._client = session.create_client("rds-data") self._dbname = dbname 
self._aurora_cluster_arn = aurora_cluster_arn or os.environ.get("AURORA_CLUSTER_ARN") self._secret_arn = secret_arn or os.environ.get("AURORA_SECRET_ARN") diff --git a/setup.py b/setup.py index 03b75f8..5c5625a 100755 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ author_email="kislyuk@gmail.com", description="A Python DB-API 2.0 client for the AWS Aurora Serverless Data API", long_description=open("README.rst").read(), - install_requires=["boto3 >= 1.34.10, < 2"], + install_requires=["botocore >= 1.38.40, < 2"], extras_require={}, packages=find_packages(exclude=["test"]), platforms=["MacOS X", "Posix"], From db7f868d5da755506e62798006f6d5532ed94324 Mon Sep 17 00:00:00 2001 From: tsuga <2888173+tsuga@users.noreply.github.com> Date: Sun, 20 Jul 2025 12:12:26 +0000 Subject: [PATCH 3/9] Move existing code to sync module --- aurora_data_api/__init__,py | 4 +++ aurora_data_api/{__init__.py => sync.py} | 8 +++--- test/{test.py => test_sync.py} | 36 ++++++++++++------------ 3 files changed, 26 insertions(+), 22 deletions(-) create mode 100644 aurora_data_api/__init__,py rename aurora_data_api/{__init__.py => sync.py} (99%) rename test/{test.py => test_sync.py} (86%) diff --git a/aurora_data_api/__init__,py b/aurora_data_api/__init__,py new file mode 100644 index 0000000..4be7880 --- /dev/null +++ b/aurora_data_api/__init__,py @@ -0,0 +1,4 @@ +# For backward compatibility +from .sync import connect +from .sync import SyncAuroraDataAPIClient as AuroraDataAPIClient +from .sync import SyncAuroraDataAPICursor as AuroraDataAPICursor \ No newline at end of file diff --git a/aurora_data_api/__init__.py b/aurora_data_api/sync.py similarity index 99% rename from aurora_data_api/__init__.py rename to aurora_data_api/sync.py index 43a23dc..21e2128 100644 --- a/aurora_data_api/__init__.py +++ b/aurora_data_api/sync.py @@ -49,7 +49,7 @@ logger = logging.getLogger(__name__) -class AuroraDataAPIClient: +class SyncAuroraDataAPIClient: _client_init_lock = threading.Lock() def __init__( @@ -101,7 +101,7 @@ def cursor(self): secretArn=self._secret_arn, ) self._transaction_id = res["transactionId"] - cursor = AuroraDataAPICursor( + cursor = SyncAuroraDataAPICursor( client=self._client, dbname=self._dbname, aurora_cluster_arn=self._aurora_cluster_arn, @@ -123,7 +123,7 @@ def __exit__(self, err_type, value, traceback): self.commit() -class AuroraDataAPICursor: +class SyncAuroraDataAPICursor: _pg_type_map = { "int": int, "int2": int, @@ -445,7 +445,7 @@ def connect( charset=None, continue_after_timeout=None, ): - return AuroraDataAPIClient( + return SyncAuroraDataAPIClient( dbname=database, aurora_cluster_arn=aurora_cluster_arn, secret_arn=secret_arn, diff --git a/test/test.py b/test/test_sync.py similarity index 86% rename from test/test.py rename to test/test_sync.py index 7ce21ab..e3f88df 100644 --- a/test/test.py +++ b/test/test_sync.py @@ -8,7 +8,7 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) -import aurora_data_api # noqa +import aurora_data_api.sync as sync # noqa from aurora_data_api.error_codes_mysql import MySQLErrorCodes # noqa from aurora_data_api.error_codes_postgresql import PostgreSQLErrorCodes # noqa @@ -23,7 +23,7 @@ class TestAuroraDataAPI(unittest.TestCase): @classmethod def setUpClass(cls): cls.db_name = os.environ.get("AURORA_DB_NAME", __name__) - with aurora_data_api.connect(database=cls.db_name) as conn, conn.cursor() as cur: + with sync.connect(database=cls.db_name) as conn, conn.cursor() as cur: try: cur.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp"') 
cur.execute("DROP TABLE IF EXISTS aurora_data_api_test") @@ -54,7 +54,7 @@ def setUpClass(cls): for i in range(2048) ], ) - except aurora_data_api.MySQLError.ER_PARSE_ERROR: + except sync.MySQLError.ER_PARSE_ERROR: cls.using_mysql = True cur.execute("DROP TABLE IF EXISTS aurora_data_api_test") cur.execute( @@ -79,18 +79,18 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - with aurora_data_api.connect(database=cls.db_name) as conn, conn.cursor() as cur: + with sync.connect(database=cls.db_name) as conn, conn.cursor() as cur: cur.execute("DROP TABLE IF EXISTS aurora_data_api_test") def test_invalid_statements(self): - with aurora_data_api.connect(database=self.db_name) as conn, conn.cursor() as cur: + with sync.connect(database=self.db_name) as conn, conn.cursor() as cur: with self.assertRaises( - (aurora_data_api.exceptions.PostgreSQLError.ER_SYNTAX_ERR, aurora_data_api.MySQLError.ER_PARSE_ERROR) + (sync.exceptions.PostgreSQLError.ER_SYNTAX_ERR, sync.MySQLError.ER_PARSE_ERROR) ): cur.execute("selec * from table") def test_iterators(self): - with aurora_data_api.connect(database=self.db_name) as conn, conn.cursor() as cur: + with sync.connect(database=self.db_name) as conn, conn.cursor() as cur: if not self.using_mysql: cur.execute("select count(*) from aurora_data_api_test where pg_column_size(doc) < :s", dict(s=2**6)) self.assertEqual(cur.fetchone()[0], 0) @@ -153,7 +153,7 @@ def test_iterators(self): def test_pagination_backoff(self): if self.using_mysql: return - with aurora_data_api.connect(database=self.db_name) as conn, conn.cursor() as cur: + with sync.connect(database=self.db_name) as conn, conn.cursor() as cur: sql_template = "select concat({}) from aurora_data_api_test" sql = sql_template.format(", ".join(["cast(doc as text)"] * 64)) cur.execute(sql) @@ -170,27 +170,27 @@ def test_pagination_backoff(self): def test_postgres_exceptions(self): if self.using_mysql: return - with aurora_data_api.connect(database=self.db_name) as conn, conn.cursor() as cur: + with sync.connect(database=self.db_name) as conn, conn.cursor() as cur: table = "aurora_data_api_nonexistent_test_table" - with self.assertRaises(aurora_data_api.exceptions.PostgreSQLError.ER_UNDEF_TABLE) as e: + with self.assertRaises(sync.exceptions.PostgreSQLError.ER_UNDEF_TABLE) as e: sql = f"select * from {table}" cur.execute(sql) self.assertTrue(f'relation "{table}" does not exist' in str(e.exception)) self.assertTrue(isinstance(e.exception.response, dict)) def test_rowcount(self): - with aurora_data_api.connect(database=self.db_name) as conn, conn.cursor() as cur: + with sync.connect(database=self.db_name) as conn, conn.cursor() as cur: cur.execute("select * from aurora_data_api_test limit 8") self.assertEqual(cur.rowcount, 8) - with aurora_data_api.connect(database=self.db_name) as conn, conn.cursor() as cur: + with sync.connect(database=self.db_name) as conn, conn.cursor() as cur: cur.execute("select * from aurora_data_api_test limit 9000") self.assertEqual(cur.rowcount, 2048) if self.using_mysql: return - with aurora_data_api.connect(database=self.db_name) as conn, conn.cursor() as cur: + with sync.connect(database=self.db_name) as conn, conn.cursor() as cur: cur.executemany( "INSERT INTO aurora_data_api_test(name, doc) VALUES (:name, CAST(:doc AS JSONB))", [ @@ -216,7 +216,7 @@ def test_continue_after_timeout(self): self.skipTest("Not implemented for MySQL") try: - with aurora_data_api.connect(database=self.db_name) as conn, conn.cursor() as cur: + with sync.connect(database=self.db_name) as conn, 
conn.cursor() as cur: with self.assertRaisesRegex(conn._client.exceptions.ClientError, "StatementTimeoutException"): cur.execute( ( @@ -224,14 +224,14 @@ def test_continue_after_timeout(self): "FROM (SELECT pg_sleep(50)) q" ) ) - with self.assertRaisesRegex(aurora_data_api.DatabaseError, "current transaction is aborted"): + with self.assertRaisesRegex(sync.DatabaseError, "current transaction is aborted"): cur.execute("SELECT COUNT(*) FROM aurora_data_api_test WHERE name = 'continue_after_timeout'") - with aurora_data_api.connect(database=self.db_name) as conn, conn.cursor() as cur: + with sync.connect(database=self.db_name) as conn, conn.cursor() as cur: cur.execute("SELECT COUNT(*) FROM aurora_data_api_test WHERE name = 'continue_after_timeout'") self.assertEqual(cur.fetchone(), (0,)) - with aurora_data_api.connect( + with sync.connect( database=self.db_name, continue_after_timeout=True ) as conn, conn.cursor() as cur: with self.assertRaisesRegex(conn._client.exceptions.ClientError, "StatementTimeoutException"): @@ -244,7 +244,7 @@ def test_continue_after_timeout(self): cur.execute("SELECT COUNT(*) FROM aurora_data_api_test WHERE name = 'continue_after_timeout'") self.assertEqual(cur.fetchone(), (1,)) finally: - with aurora_data_api.connect(database=self.db_name) as conn, conn.cursor() as cur: + with sync.connect(database=self.db_name) as conn, conn.cursor() as cur: cur.execute("DELETE FROM aurora_data_api_test WHERE name = 'continue_after_timeout'") From c81c6df51c334f3ce5c000b14c2c349132e7bce0 Mon Sep 17 00:00:00 2001 From: tsuga <2888173+tsuga@users.noreply.github.com> Date: Sun, 20 Jul 2025 12:45:05 +0000 Subject: [PATCH 4/9] Enhance test code for sync --- test/test_sync.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/test_sync.py b/test/test_sync.py index e3f88df..e4a6715 100644 --- a/test/test_sync.py +++ b/test/test_sync.py @@ -23,7 +23,9 @@ class TestAuroraDataAPI(unittest.TestCase): @classmethod def setUpClass(cls): cls.db_name = os.environ.get("AURORA_DB_NAME", __name__) - with sync.connect(database=cls.db_name) as conn, conn.cursor() as cur: + cls.cluster_arn = os.environ.get("AURORA_CLUSTER_ARN") + cls.secret_arn = os.environ.get("SECRET_ARN") + with sync.connect(database=cls.db_name, aurora_cluster_arn=cls.cluster_arn, secret_arn=cls.secret_arn) as conn, conn.cursor() as cur: try: cur.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp"') cur.execute("DROP TABLE IF EXISTS aurora_data_api_test") From 15ffdaedda359387321fe813ec766dad2894d0a3 Mon Sep 17 00:00:00 2001 From: tsuga <2888173+tsuga@users.noreply.github.com> Date: Sun, 20 Jul 2025 12:45:18 +0000 Subject: [PATCH 5/9] Add async code --- aurora_data_api/async_.py | 487 ++++++++++++++++++++++++++++++++++++++ setup.py | 4 +- test/test_async.py | 318 +++++++++++++++++++++++++ 3 files changed, 808 insertions(+), 1 deletion(-) create mode 100644 aurora_data_api/async_.py create mode 100644 test/test_async.py diff --git a/aurora_data_api/async_.py b/aurora_data_api/async_.py new file mode 100644 index 0000000..4a2a56f --- /dev/null +++ b/aurora_data_api/async_.py @@ -0,0 +1,487 @@ +""" +aurora-data-api - A Python DB-API 2.0 client for the AWS Aurora Serverless Data API (Async version) +""" +import os, datetime, ipaddress, uuid, time, random, string, logging, itertools, reprlib, json, re, asyncio +from decimal import Decimal +from collections import namedtuple +from collections.abc import Mapping +from .exceptions import ( + Warning, + Error, + InterfaceError, + DatabaseError, + DataError, + 
OperationalError, + IntegrityError, + InternalError, + ProgrammingError, + NotSupportedError, + MySQLError, + PostgreSQLError, +) +from .error_codes_mysql import MySQLErrorCodes +from .error_codes_postgresql import PostgreSQLErrorCodes +import aiobotocore.session + +apilevel = "2.0" + +threadsafety = 0 + +paramstyle = "named" + +Date = datetime.date +Time = datetime.time +Timestamp = datetime.datetime +DateFromTicks = datetime.date.fromtimestamp +# TimeFromTicks = datetime.time.fromtimestamp TODO +TimestampFromTicks = datetime.datetime.fromtimestamp +Binary = bytes +STRING = str +BINARY = bytes +NUMBER = float +DATETIME = datetime.datetime +ROWID = str +DECIMAL = Decimal + +ColumnDescription = namedtuple("ColumnDescription", "name type_code display_size internal_size precision scale null_ok") +ColumnDescription.__new__.__defaults__ = (None,) * len(ColumnDescription._fields) + +logger = logging.getLogger(__name__) + + +class AsyncAuroraDataAPIClient: + def __init__( + self, + dbname=None, + aurora_cluster_arn=None, + secret_arn=None, + rds_data_client=None, + charset=None, + continue_after_timeout=None, + ): + self._client = rds_data_client + self._session = None + if rds_data_client is None: + self._session = aiobotocore.session.get_session() + self._client = None # Will be created when needed + self._dbname = dbname + self._aurora_cluster_arn = aurora_cluster_arn or os.environ.get("AURORA_CLUSTER_ARN") + self._secret_arn = secret_arn or os.environ.get("AURORA_SECRET_ARN") + self._charset = charset + self._transaction_id = None + self._continue_after_timeout = continue_after_timeout + self._client_context = None + + async def _ensure_client(self): + if self._client is None and self._session: + self._client_context = self._session.create_client("rds-data") + self._client = await self._client_context.__aenter__() + + async def close(self): + if self._client_context: + await self._client_context.__aexit__(None, None, None) + self._client_context = None + self._client = None + + async def commit(self): + if self._transaction_id: + await self._ensure_client() + res = await self._client.commit_transaction( + resourceArn=self._aurora_cluster_arn, secretArn=self._secret_arn, transactionId=self._transaction_id + ) + self._transaction_id = None + if res["transactionStatus"] != "Transaction Committed": + raise DatabaseError("Error while committing transaction: {}".format(res)) + + async def rollback(self): + if self._transaction_id: + await self._ensure_client() + await self._client.rollback_transaction( + resourceArn=self._aurora_cluster_arn, secretArn=self._secret_arn, transactionId=self._transaction_id + ) + self._transaction_id = None + + async def cursor(self): + if self._transaction_id is None: + await self._ensure_client() + res = await self._client.begin_transaction( + database=self._dbname, + resourceArn=self._aurora_cluster_arn, + # schema="string", TODO + secretArn=self._secret_arn, + ) + self._transaction_id = res["transactionId"] + cursor = AsyncAuroraDataAPICursor( + client=self._client, + dbname=self._dbname, + aurora_cluster_arn=self._aurora_cluster_arn, + secret_arn=self._secret_arn, + transaction_id=self._transaction_id, + continue_after_timeout=self._continue_after_timeout, + ) + if self._charset: + await cursor.execute("SET character_set_client = '{}'".format(self._charset)) + return cursor + + async def __aenter__(self): + await self._ensure_client() + return self + + async def __aexit__(self, err_type, value, traceback): + if err_type is not None: + await self.rollback() + else: 
+ await self.commit() + await self.close() + + +class AsyncAuroraDataAPICursor: + _pg_type_map = { + "int": int, + "int2": int, + "int4": int, + "int8": int, + "float4": float, + "float8": float, + "serial2": int, + "serial4": int, + "serial8": int, + "bool": bool, + "varbit": bytes, + "bytea": bytearray, + "char": str, + "varchar": str, + "cidr": ipaddress.ip_network, + "date": datetime.date, + "inet": ipaddress.ip_address, + "json": dict, + "jsonb": dict, + "money": str, + "text": str, + "time": datetime.time, + "timestamp": datetime.datetime, + "uuid": uuid.uuid4, + "numeric": Decimal, + "decimal": Decimal, + } + _data_api_type_map = { + bytes: "blobValue", + bool: "booleanValue", + float: "doubleValue", + int: "longValue", + str: "stringValue", + Decimal: "stringValue", + # list: "arrayValue" + } + _data_api_type_hint_map = { + datetime.date: "DATE", + datetime.time: "TIME", + datetime.datetime: "TIMESTAMP", + Decimal: "DECIMAL", + } + + def __init__( + self, + client=None, + dbname=None, + aurora_cluster_arn=None, + secret_arn=None, + transaction_id=None, + continue_after_timeout=None, + ): + self.arraysize = 1000 + self.description = None + self._client = client + self._dbname = dbname + self._aurora_cluster_arn = aurora_cluster_arn + self._secret_arn = secret_arn + self._transaction_id = transaction_id + self._current_response = None + self._iterator = None + self._paging_state = None + self._continue_after_timeout = continue_after_timeout + + def prepare_param(self, param_name, param_value): + if param_value is None: + return dict(name=param_name, value=dict(isNull=True)) + param_data_api_type = self._data_api_type_map.get(type(param_value), "stringValue") + param = dict(name=param_name, value={param_data_api_type: param_value}) + if param_data_api_type == "stringValue" and not isinstance(param_value, str): + param["value"][param_data_api_type] = str(param_value) + if type(param_value) in self._data_api_type_hint_map: + param["typeHint"] = self._data_api_type_hint_map[type(param_value)] + return param + + # if param_data_api_type == "arrayValue" and len(param_value) > 0: + # return { + # param_data_api_type: { + # self._data_api_type_map.get(type(param_value[0]), "stringValue") + "s": param_value + # } + # } + + def _set_description(self, column_metadata): + # see https://www.postgresql.org/docs/9.5/datatype.html + self.description = [] + for column in column_metadata: + col_desc = ColumnDescription( + name=column["name"], type_code=self._pg_type_map.get(column["typeName"].lower(), str) + ) + self.description.append(col_desc) + + async def _start_paginated_query(self, execute_statement_args, records_per_page=None): + # MySQL cursors are non-scrollable (https://dev.mysql.com/doc/refman/8.0/en/cursors.html) + # - may not support page autosizing + # - FETCH requires INTO - may need to write all results into a server side var and iterate on it + pg_cursor_name = "{}_{}_{}".format( + __name__, int(time.time()), "".join(random.choices(string.ascii_letters + string.digits, k=8)) + ) + cursor_stmt = "DECLARE " + pg_cursor_name + " SCROLL CURSOR FOR " + execute_statement_args["sql"] = cursor_stmt + execute_statement_args["sql"] + await self._client.execute_statement(**execute_statement_args) + self._paging_state = { + "execute_statement_args": dict(execute_statement_args), + "records_per_page": records_per_page or self.arraysize, + "pg_cursor_name": pg_cursor_name, + } + + def _prepare_execute_args(self, operation): + execute_args = dict( + database=self._dbname, 
resourceArn=self._aurora_cluster_arn, secretArn=self._secret_arn, sql=operation + ) + if self._transaction_id: + execute_args["transactionId"] = self._transaction_id + return execute_args + + def _format_parameter_set(self, parameters): + if not isinstance(parameters, Mapping): + raise NotSupportedError("Expected a mapping of parameters. Array parameters are not supported.") + return [self.prepare_param(k, v) for k, v in parameters.items()] + + def _get_database_error(self, original_error): + error_msg = getattr(original_error, "response", {}).get("Error", {}).get("Message", "") + try: + res = re.search(r"Error code: (\d+); SQLState: (\d+)$", error_msg) + if res: # MySQL error + error_code = int(res.group(1)) + error_class = MySQLError.from_code(error_code) + error = error_class(error_msg) + error.response = getattr(original_error, "response", {}) + return error + res = re.search(r"ERROR: .*(?:\n |;) Position: (\d+); SQLState: (\w+)$", error_msg) + if res: # PostgreSQL error + error_code = res.group(2) + error_class = PostgreSQLError.from_code(error_code) + error = error_class(error_msg) + error.response = getattr(original_error, "response", {}) + return error + except Exception: + pass + return DatabaseError(original_error) + + async def execute(self, operation, parameters=None): + self._current_response, self._iterator, self._paging_state = None, None, None + execute_statement_args = dict(self._prepare_execute_args(operation), includeResultMetadata=True) + if self._continue_after_timeout is not None: + execute_statement_args["continueAfterTimeout"] = self._continue_after_timeout + if parameters: + execute_statement_args["parameters"] = self._format_parameter_set(parameters) + logger.debug("execute %s", reprlib.repr(operation.strip())) + try: + res = await self._client.execute_statement(**execute_statement_args) + if "columnMetadata" in res: + self._set_description(res["columnMetadata"]) + self._current_response = self._render_response(res) + except (self._client.exceptions.BadRequestException, self._client.exceptions.DatabaseErrorException) as e: + if "Please paginate your query" in str(e): + await self._start_paginated_query(execute_statement_args) + elif "Database returned more than the allowed response size limit" in str(e): + await self._start_paginated_query(execute_statement_args, records_per_page=max(1, self.arraysize // 2)) + else: + raise self._get_database_error(e) from e + self._iterator = self.__aiter__() + + @property + def rowcount(self): + if self._current_response: + if "records" in self._current_response: + return len(self._current_response["records"]) + elif "numberOfRecordsUpdated" in self._current_response: + return self._current_response["numberOfRecordsUpdated"] + return -1 + + @property + def lastrowid(self): + # TODO: this may not make sense if the previous statement is not an INSERT + if self._current_response and self._current_response.get("generatedFields"): + return self._render_value(self._current_response["generatedFields"][-1]) + + def _page_input(self, iterable, page_size=1000): + iterable = iter(iterable) + return iter(lambda: list(itertools.islice(iterable, page_size)), []) + + async def executemany(self, operation, seq_of_parameters): + logger.debug("executemany %s", reprlib.repr(operation.strip())) + for batch in self._page_input(seq_of_parameters): + batch_execute_statement_args = dict( + self._prepare_execute_args(operation), parameterSets=[self._format_parameter_set(p) for p in batch] + ) + try: + await 
self._client.batch_execute_statement(**batch_execute_statement_args) + except self._client.exceptions.BadRequestException as e: + raise self._get_database_error(e) from e + + def _render_response(self, response): + if "records" in response: + for i, record in enumerate(response["records"]): + response["records"][i] = tuple( + self._render_value(value, col_desc=self.description[j] if self.description else None) + for j, value in enumerate(record) + ) + return response + + def _render_value(self, value, col_desc=None): + if value.get("isNull"): + return None + elif "arrayValue" in value: + if "arrayValues" in value["arrayValue"]: + return [self._render_value(nested) for nested in value["arrayValue"]["arrayValues"]] + else: + return list(value["arrayValue"].values())[0] + else: + scalar_value = list(value.values())[0] + if col_desc and col_desc.type_code in self._data_api_type_hint_map: + if col_desc.type_code == Decimal: + scalar_value = Decimal(scalar_value) + else: + try: + scalar_value = col_desc.type_code.fromisoformat(scalar_value) + except (AttributeError, ValueError): # fromisoformat not supported on Python < 3.7 + if col_desc.type_code == datetime.date: + scalar_value = datetime.datetime.strptime(scalar_value, "%Y-%m-%d").date() + elif col_desc.type_code == datetime.time: + scalar_value = datetime.datetime.strptime(scalar_value, "%H:%M:%S").time() + elif "." in scalar_value: + scalar_value = datetime.datetime.strptime(scalar_value, "%Y-%m-%d %H:%M:%S.%f") + else: + scalar_value = datetime.datetime.strptime(scalar_value, "%Y-%m-%d %H:%M:%S") + return scalar_value + + async def scroll(self, value, mode="relative"): + if not self._paging_state: + raise InterfaceError("Cursor scroll attempted but pagination is not active") + scroll_stmt = "MOVE {mode} {value} FROM {pg_cursor_name}".format( + mode=mode.upper(), value=value, **self._paging_state + ) + scroll_args = dict(self._paging_state["execute_statement_args"], sql=scroll_stmt) + logger.debug("Scrolling cursor %s by %d rows", mode, value) + await self._client.execute_statement(**scroll_args) + + def __aiter__(self): + return self + + async def __anext__(self): + if self._paging_state: + if not hasattr(self, '_page_iterator'): + self._page_iterator = self._fetch_paginated_records() + try: + return await self._page_iterator.__anext__() + except StopAsyncIteration: + raise StopAsyncIteration + else: + if not hasattr(self, '_record_index'): + self._record_index = 0 + records = self._current_response.get("records", []) + if self._record_index >= len(records): + raise StopAsyncIteration + record = records[self._record_index] + self._record_index += 1 + return record + + async def _fetch_paginated_records(self): + next_page_args = self._paging_state["execute_statement_args"] + while True: + logger.debug( + "Fetching page of %d records for auto-paginated query", self._paging_state["records_per_page"] + ) + next_page_args["sql"] = "FETCH {records_per_page} FROM {pg_cursor_name}".format(**self._paging_state) + try: + page = await self._client.execute_statement(**next_page_args) + except self._client.exceptions.BadRequestException as e: + cur_rpp = self._paging_state["records_per_page"] + if "Database returned more than the allowed response size limit" in str(e) and cur_rpp > 1: + await self.scroll(-self._paging_state["records_per_page"]) # Rewind the cursor to read the page again + logger.debug("Halving records per page") + self._paging_state["records_per_page"] //= 2 + continue + else: + raise self._get_database_error(e) from e + + if 
"columnMetadata" in page and not self.description: + self._set_description(page["columnMetadata"]) + if len(page["records"]) == 0: + break + page = self._render_response(page) + for record in page["records"]: + yield record + + async def fetchone(self): + try: + return await self.__anext__() + except StopAsyncIteration: + return None + + async def fetchmany(self, size=None): + if size is None: + size = self.arraysize + results = [] + while size > 0: + result = await self.fetchone() + if result is None: + break + results.append(result) + size -= 1 + return results + + async def fetchall(self): + results = [] + async for record in self: + results.append(record) + return results + + def setinputsizes(self, sizes): + pass + + def setoutputsize(self, size, column=None): + pass + + async def close(self): + pass + + async def __aenter__(self): + return self + + async def __aexit__(self, err_type, value, traceback): + self._iterator = None + self._current_response = None + + +def connect( + aurora_cluster_arn=None, + secret_arn=None, + rds_data_client=None, + database=None, + host=None, + port=None, + username=None, + password=None, + charset=None, + continue_after_timeout=None, +): + return AsyncAuroraDataAPIClient( + dbname=database, + aurora_cluster_arn=aurora_cluster_arn, + secret_arn=secret_arn, + rds_data_client=rds_data_client, + charset=charset, + continue_after_timeout=continue_after_timeout, + ) diff --git a/setup.py b/setup.py index 5c5625a..1981830 100755 --- a/setup.py +++ b/setup.py @@ -12,7 +12,9 @@ description="A Python DB-API 2.0 client for the AWS Aurora Serverless Data API", long_description=open("README.rst").read(), install_requires=["botocore >= 1.38.40, < 2"], - extras_require={}, + extras_require={ + "async": "aiobotocore >= 2.23.1, < 3" + }, packages=find_packages(exclude=["test"]), platforms=["MacOS X", "Posix"], test_suite="test", diff --git a/test/test_async.py b/test/test_async.py new file mode 100644 index 0000000..a244764 --- /dev/null +++ b/test/test_async.py @@ -0,0 +1,318 @@ +import asyncio +import datetime +import decimal +import json +import logging +import os +import sys +import unittest + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +import aurora_data_api.async_ as async_ # noqa +from aurora_data_api.error_codes_mysql import MySQLErrorCodes # noqa +from aurora_data_api.error_codes_postgresql import PostgreSQLErrorCodes # noqa + +logging.basicConfig(level=logging.INFO) +logging.getLogger("aurora_data_api").setLevel(logging.DEBUG) +logging.getLogger("urllib3.connectionpool").setLevel(logging.DEBUG) + + +class AsyncTestCase(unittest.TestCase): + """Base class for async test cases.""" + + def setUp(self): + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + + def tearDown(self): + self.loop.close() + + def async_test(coro): + """Decorator to run async test methods.""" + def wrapper(self): + return self.loop.run_until_complete(coro(self)) + return wrapper + + +class TestAuroraDataAPI(AsyncTestCase): + using_mysql = False + + @classmethod + def setUpClass(cls): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete(cls._async_setUpClass()) + finally: + loop.close() + + @classmethod + async def _async_setUpClass(cls): + cls.db_name = os.environ.get("AURORA_DB_NAME", __name__) + cls.cluster_arn = os.environ.get("AURORA_CLUSTER_ARN") + cls.secret_arn = os.environ.get("SECRET_ARN") + async with await async_.connect(database=cls.db_name, aurora_cluster_arn=cls.cluster_arn, 
secret_arn=cls.secret_arn) as conn: + cur = await conn.cursor() + try: + await cur.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp"') + await cur.execute("DROP TABLE IF EXISTS aurora_data_api_test") + await cur.execute( + """ + CREATE TABLE aurora_data_api_test ( + id SERIAL, + name TEXT, + doc JSONB DEFAULT '{}', + num NUMERIC (10, 5) DEFAULT 0.0, + ts TIMESTAMP WITHOUT TIME ZONE + ) + """ + ) + await cur.executemany( + """ + INSERT INTO aurora_data_api_test(name, doc, num, ts) + VALUES (:name, CAST(:doc AS JSONB), :num, CAST(:ts AS TIMESTAMP)) + """, + [ + { + "name": "row{}".format(i), + # Note: data api v1 supports up to 512**512 but v2 only supports up to 128**128 + "doc": json.dumps({"x": i, "y": str(i), "z": [i, i * i, i**i if i < 128 else 0]}), + "num": decimal.Decimal("%d.%d" % (i, i)), + "ts": "2020-09-17 13:49:32.780180", + } + for i in range(2048) + ], + ) + except async_.MySQLError.ER_PARSE_ERROR: + cls.using_mysql = True + await cur.execute("DROP TABLE IF EXISTS aurora_data_api_test") + await cur.execute( + "CREATE TABLE aurora_data_api_test " + "(id SERIAL, name TEXT, birthday DATE, num NUMERIC(10, 5), ts TIMESTAMP)" + ) + await cur.executemany( + ( + "INSERT INTO aurora_data_api_test(name, birthday, num, ts) VALUES " + "(:name, :birthday, :num, CAST(:ts AS DATETIME))" + ), + [ + { + "name": "row{}".format(i), + "birthday": "2000-01-01", + "num": decimal.Decimal("%d.%d" % (i, i)), + "ts": "2020-09-17 13:49:32.780180", + } + for i in range(2048) + ], + ) + + @classmethod + def tearDownClass(cls): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete(cls._async_tearDownClass()) + finally: + loop.close() + + @classmethod + async def _async_tearDownClass(cls): + async with await async_.connect(database=cls.db_name, aurora_cluster_arn=cls.cluster_arn, secret_arn=cls.secret_arn) as conn: + cur = await conn.cursor() + await cur.execute("DROP TABLE IF EXISTS aurora_data_api_test") + + @AsyncTestCase.async_test + async def test_invalid_statements(self): + async with await async_.connect(database=self.db_name, aurora_cluster_arn=self.cluster_arn, secret_arn=self.secret_arn) as conn: + cur = await conn.cursor() + with self.assertRaises( + (async_.exceptions.PostgreSQLError.ER_SYNTAX_ERR, async_.MySQLError.ER_PARSE_ERROR) + ): + await cur.execute("selec * from table") + + @AsyncTestCase.async_test + async def test_iterators(self): + async with await async_.connect(database=self.db_name, aurora_cluster_arn=self.cluster_arn, secret_arn=self.secret_arn) as conn: + cur = await conn.cursor() + if not self.using_mysql: + await cur.execute("select count(*) from aurora_data_api_test where pg_column_size(doc) < :s", dict(s=2**6)) + result = await cur.fetchone() + self.assertEqual(result[0], 0) + await cur.execute("select count(*) from aurora_data_api_test where pg_column_size(doc) < :s", dict(s=2**7)) + result = await cur.fetchone() + self.assertEqual(result[0], 1977) + await cur.execute("select count(*) from aurora_data_api_test where pg_column_size(doc) < :s", dict(s=2**8)) + result = await cur.fetchone() + self.assertEqual(result[0], 2048) + await cur.execute("select count(*) from aurora_data_api_test where pg_column_size(doc) < :s", dict(s=2**10)) + result = await cur.fetchone() + self.assertEqual(result[0], 2048) + + cursor = await conn.cursor() + expect_row0 = ( + 1, + "row0", + # Note: data api v1 used JSON serialization with extra whitespace; v2 uses compact serialization + datetime.date(2000, 1, 1) if self.using_mysql else 
'{"x":0,"y":"0","z":[0,0,1]}', + decimal.Decimal(0.0), + datetime.datetime(2020, 9, 17, 13, 49, 33) + if self.using_mysql + else datetime.datetime(2020, 9, 17, 13, 49, 32, 780180), + ) + i = 0 + await cursor.execute("select * from aurora_data_api_test") + async for f in cursor: + if i == 0: + self.assertEqual(f, expect_row0) + i += 1 + self.assertEqual(i, 2048) + + await cursor.execute("select * from aurora_data_api_test") + data = await cursor.fetchall() + self.assertEqual(data[0], expect_row0) + self.assertEqual(data[-1][0], 2048) + self.assertEqual(data[-1][1], "row2047") + if not self.using_mysql: + self.assertEqual(json.loads(data[-1][2]), {"x": 2047, "y": str(2047), "z": [2047, 2047 * 2047, 0]}) + self.assertEqual(data[-1][-2], decimal.Decimal("2047.2047")) + self.assertEqual(len(data), 2048) + self.assertEqual(len(await cursor.fetchall()), 0) + + await cursor.execute("select * from aurora_data_api_test") + i = 0 + while True: + result = await cursor.fetchone() + if not result: + break + i += 1 + self.assertEqual(i, 2048) + + await cursor.execute("select * from aurora_data_api_test") + while True: + fm = await cursor.fetchmany(1001) + if not fm: + break + self.assertIn(len(fm), [1001, 46]) + + @unittest.skip( + "This test now fails because the API was changed to terminate and delete the transaction when the " + "data returned by the statement exceeds the limit, making automated recovery impossible." + ) + @AsyncTestCase.async_test + async def test_pagination_backoff(self): + if self.using_mysql: + return + async with await async_.connect(database=self.db_name, aurora_cluster_arn=self.cluster_arn, secret_arn=self.secret_arn) as conn: + cur = await conn.cursor() + sql_template = "select concat({}) from aurora_data_api_test" + sql = sql_template.format(", ".join(["cast(doc as text)"] * 64)) + await cur.execute(sql) + result = await cur.fetchall() + self.assertEqual(len(result), 2048) + + concat_args = ", ".join(["cast(doc as text)"] * 100) + sql = sql_template.format(", ".join("concat({})".format(concat_args) for i in range(32))) + await cur.execute(sql) + with self.assertRaisesRegex( + Exception, "Database response exceeded size limit" + ): + await cur.fetchall() + + @AsyncTestCase.async_test + async def test_postgres_exceptions(self): + if self.using_mysql: + return + async with await async_.connect(database=self.db_name, aurora_cluster_arn=self.cluster_arn, secret_arn=self.secret_arn) as conn: + cur = await conn.cursor() + table = "aurora_data_api_nonexistent_test_table" + with self.assertRaises(async_.exceptions.PostgreSQLError.ER_UNDEF_TABLE) as e: + sql = f"select * from {table}" + await cur.execute(sql) + self.assertTrue(f'relation "{table}" does not exist' in str(e.exception)) + self.assertTrue(isinstance(e.exception.response, dict)) + + @AsyncTestCase.async_test + async def test_rowcount(self): + async with await async_.connect(database=self.db_name, aurora_cluster_arn=self.cluster_arn, secret_arn=self.secret_arn) as conn: + cur = await conn.cursor() + await cur.execute("select * from aurora_data_api_test limit 8") + self.assertEqual(cur.rowcount, 8) + + async with await async_.connect(database=self.db_name, aurora_cluster_arn=self.cluster_arn, secret_arn=self.secret_arn) as conn: + cur = await conn.cursor() + await cur.execute("select * from aurora_data_api_test limit 9000") + self.assertEqual(cur.rowcount, 2048) + + if self.using_mysql: + return + + async with await async_.connect(database=self.db_name, aurora_cluster_arn=self.cluster_arn, secret_arn=self.secret_arn) as conn: 
+ cur = await conn.cursor() + await cur.executemany( + "INSERT INTO aurora_data_api_test(name, doc) VALUES (:name, CAST(:doc AS JSONB))", + [ + { + "name": "rowcount{}".format(i), + "doc": json.dumps({"x": i, "y": str(i), "z": [i, i * i, i**i if i < 512 else 0]}), + } + for i in range(8) + ], + ) + + await cur.execute("UPDATE aurora_data_api_test SET doc = '{}' WHERE name like 'rowcount%'") + self.assertEqual(cur.rowcount, 8) + + await cur.execute("DELETE FROM aurora_data_api_test WHERE name like 'rowcount%'") + self.assertEqual(cur.rowcount, 8) + + @AsyncTestCase.async_test + async def test_continue_after_timeout(self): + if os.environ.get("TEST_CONTINUE_AFTER_TIMEOUT", "False") != "True": + self.skipTest("TEST_CONTINUE_AFTER_TIMEOUT env var is not 'True'") + + if self.using_mysql: + self.skipTest("Not implemented for MySQL") + + try: + async with await async_.connect(database=self.db_name, aurora_cluster_arn=self.cluster_arn, secret_arn=self.secret_arn) as conn: + cur = await conn.cursor() + with self.assertRaisesRegex(Exception, "StatementTimeoutException"): + await cur.execute( + ( + "INSERT INTO aurora_data_api_test(name) SELECT 'continue_after_timeout'" + "FROM (SELECT pg_sleep(50)) q" + ) + ) + with self.assertRaisesRegex(async_.DatabaseError, "current transaction is aborted"): + await cur.execute("SELECT COUNT(*) FROM aurora_data_api_test WHERE name = 'continue_after_timeout'") + + async with await async_.connect(database=self.db_name, aurora_cluster_arn=self.cluster_arn, secret_arn=self.secret_arn) as conn: + cur = await conn.cursor() + await cur.execute("SELECT COUNT(*) FROM aurora_data_api_test WHERE name = 'continue_after_timeout'") + result = await cur.fetchone() + self.assertEqual(result, (0,)) + + async with await async_.connect( + database=self.db_name, continue_after_timeout=True, + aurora_cluster_arn=self.cluster_arn, secret_arn=self.secret_arn + ) as conn: + cur = await conn.cursor() + with self.assertRaisesRegex(Exception, "StatementTimeoutException"): + await cur.execute( + ( + "INSERT INTO aurora_data_api_test(name) SELECT 'continue_after_timeout' " + "FROM (SELECT pg_sleep(50)) q" + ) + ) + await cur.execute("SELECT COUNT(*) FROM aurora_data_api_test WHERE name = 'continue_after_timeout'") + result = await cur.fetchone() + self.assertEqual(result, (1,)) + finally: + async with await async_.connect(database=self.db_name, aurora_cluster_arn=self.cluster_arn, secret_arn=self.secret_arn) as conn: + cur = await conn.cursor() + await cur.execute("DELETE FROM aurora_data_api_test WHERE name = 'continue_after_timeout'") + + +if __name__ == "__main__": + unittest.main() From 21686fdb287e39e457000cac6691c7b3a3b3caf9 Mon Sep 17 00:00:00 2001 From: tsuga <2888173+tsuga@users.noreply.github.com> Date: Sun, 27 Jul 2025 12:42:03 +0000 Subject: [PATCH 6/9] Fix filename typo --- aurora_data_api/__init__,py | 4 ---- aurora_data_api/__init__.py | 24 ++++++++++++++++++++++++ 2 files changed, 24 insertions(+), 4 deletions(-) delete mode 100644 aurora_data_api/__init__,py create mode 100644 aurora_data_api/__init__.py diff --git a/aurora_data_api/__init__,py b/aurora_data_api/__init__,py deleted file mode 100644 index 4be7880..0000000 --- a/aurora_data_api/__init__,py +++ /dev/null @@ -1,4 +0,0 @@ -# For backward compatibility -from .sync import connect -from .sync import SyncAuroraDataAPIClient as AuroraDataAPIClient -from .sync import SyncAuroraDataAPICursor as AuroraDataAPICursor \ No newline at end of file diff --git a/aurora_data_api/__init__.py b/aurora_data_api/__init__.py new file 
mode 100644 index 0000000..948acd2 --- /dev/null +++ b/aurora_data_api/__init__.py @@ -0,0 +1,24 @@ +# For backward compatibility +from .sync import connect +from .sync import SyncAuroraDataAPIClient as AuroraDataAPIClient +from .sync import SyncAuroraDataAPICursor as AuroraDataAPICursor + +# Import all DB-API constants from base +from .base import ( + apilevel, + threadsafety, + paramstyle, + Date, + Time, + Timestamp, + DateFromTicks, + # TimeFromTicks, + TimestampFromTicks, + Binary, + STRING, + BINARY, + NUMBER, + DATETIME, + ROWID, + DECIMAL, +) From 25cd5e2a52202cc77514693b377d219f95fdd442 Mon Sep 17 00:00:00 2001 From: tsuga <2888173+tsuga@users.noreply.github.com> Date: Sun, 27 Jul 2025 12:42:42 +0000 Subject: [PATCH 7/9] Move common code to base.py --- aurora_data_api/async_.py | 223 ++------------------------------ aurora_data_api/base.py | 266 ++++++++++++++++++++++++++++++++++++++ aurora_data_api/sync.py | 229 ++------------------------------ 3 files changed, 286 insertions(+), 432 deletions(-) create mode 100644 aurora_data_api/base.py diff --git a/aurora_data_api/async_.py b/aurora_data_api/async_.py index 4a2a56f..d681a4e 100644 --- a/aurora_data_api/async_.py +++ b/aurora_data_api/async_.py @@ -5,6 +5,12 @@ from decimal import Decimal from collections import namedtuple from collections.abc import Mapping +from .base import BaseAuroraDataAPIClient, BaseAuroraDataAPICursor, logger +from .base import ( + apilevel, threadsafety, paramstyle, Date, Time, Timestamp, DateFromTicks, + TimestampFromTicks, Binary, STRING, BINARY, NUMBER, DATETIME, ROWID, DECIMAL, + ColumnDescription +) from .exceptions import ( Warning, Error, @@ -19,37 +25,10 @@ MySQLError, PostgreSQLError, ) -from .error_codes_mysql import MySQLErrorCodes -from .error_codes_postgresql import PostgreSQLErrorCodes import aiobotocore.session -apilevel = "2.0" - -threadsafety = 0 - -paramstyle = "named" - -Date = datetime.date -Time = datetime.time -Timestamp = datetime.datetime -DateFromTicks = datetime.date.fromtimestamp -# TimeFromTicks = datetime.time.fromtimestamp TODO -TimestampFromTicks = datetime.datetime.fromtimestamp -Binary = bytes -STRING = str -BINARY = bytes -NUMBER = float -DATETIME = datetime.datetime -ROWID = str -DECIMAL = Decimal - -ColumnDescription = namedtuple("ColumnDescription", "name type_code display_size internal_size precision scale null_ok") -ColumnDescription.__new__.__defaults__ = (None,) * len(ColumnDescription._fields) -logger = logging.getLogger(__name__) - - -class AsyncAuroraDataAPIClient: +class AsyncAuroraDataAPIClient(BaseAuroraDataAPIClient): def __init__( self, dbname=None, @@ -59,17 +38,11 @@ def __init__( charset=None, continue_after_timeout=None, ): - self._client = rds_data_client + super().__init__(dbname, aurora_cluster_arn, secret_arn, rds_data_client, charset, continue_after_timeout) self._session = None if rds_data_client is None: self._session = aiobotocore.session.get_session() self._client = None # Will be created when needed - self._dbname = dbname - self._aurora_cluster_arn = aurora_cluster_arn or os.environ.get("AURORA_CLUSTER_ARN") - self._secret_arn = secret_arn or os.environ.get("AURORA_SECRET_ARN") - self._charset = charset - self._transaction_id = None - self._continue_after_timeout = continue_after_timeout self._client_context = None async def _ensure_client(self): @@ -135,51 +108,7 @@ async def __aexit__(self, err_type, value, traceback): await self.close() -class AsyncAuroraDataAPICursor: - _pg_type_map = { - "int": int, - "int2": int, - "int4": int, - 
"int8": int, - "float4": float, - "float8": float, - "serial2": int, - "serial4": int, - "serial8": int, - "bool": bool, - "varbit": bytes, - "bytea": bytearray, - "char": str, - "varchar": str, - "cidr": ipaddress.ip_network, - "date": datetime.date, - "inet": ipaddress.ip_address, - "json": dict, - "jsonb": dict, - "money": str, - "text": str, - "time": datetime.time, - "timestamp": datetime.datetime, - "uuid": uuid.uuid4, - "numeric": Decimal, - "decimal": Decimal, - } - _data_api_type_map = { - bytes: "blobValue", - bool: "booleanValue", - float: "doubleValue", - int: "longValue", - str: "stringValue", - Decimal: "stringValue", - # list: "arrayValue" - } - _data_api_type_hint_map = { - datetime.date: "DATE", - datetime.time: "TIME", - datetime.datetime: "TIMESTAMP", - Decimal: "DECIMAL", - } - +class AsyncAuroraDataAPICursor(BaseAuroraDataAPICursor): def __init__( self, client=None, @@ -189,44 +118,7 @@ def __init__( transaction_id=None, continue_after_timeout=None, ): - self.arraysize = 1000 - self.description = None - self._client = client - self._dbname = dbname - self._aurora_cluster_arn = aurora_cluster_arn - self._secret_arn = secret_arn - self._transaction_id = transaction_id - self._current_response = None - self._iterator = None - self._paging_state = None - self._continue_after_timeout = continue_after_timeout - - def prepare_param(self, param_name, param_value): - if param_value is None: - return dict(name=param_name, value=dict(isNull=True)) - param_data_api_type = self._data_api_type_map.get(type(param_value), "stringValue") - param = dict(name=param_name, value={param_data_api_type: param_value}) - if param_data_api_type == "stringValue" and not isinstance(param_value, str): - param["value"][param_data_api_type] = str(param_value) - if type(param_value) in self._data_api_type_hint_map: - param["typeHint"] = self._data_api_type_hint_map[type(param_value)] - return param - - # if param_data_api_type == "arrayValue" and len(param_value) > 0: - # return { - # param_data_api_type: { - # self._data_api_type_map.get(type(param_value[0]), "stringValue") + "s": param_value - # } - # } - - def _set_description(self, column_metadata): - # see https://www.postgresql.org/docs/9.5/datatype.html - self.description = [] - for column in column_metadata: - col_desc = ColumnDescription( - name=column["name"], type_code=self._pg_type_map.get(column["typeName"].lower(), str) - ) - self.description.append(col_desc) + super().__init__(client, dbname, aurora_cluster_arn, secret_arn, transaction_id, continue_after_timeout) async def _start_paginated_query(self, execute_statement_args, records_per_page=None): # MySQL cursors are non-scrollable (https://dev.mysql.com/doc/refman/8.0/en/cursors.html) @@ -244,40 +136,6 @@ async def _start_paginated_query(self, execute_statement_args, records_per_page= "pg_cursor_name": pg_cursor_name, } - def _prepare_execute_args(self, operation): - execute_args = dict( - database=self._dbname, resourceArn=self._aurora_cluster_arn, secretArn=self._secret_arn, sql=operation - ) - if self._transaction_id: - execute_args["transactionId"] = self._transaction_id - return execute_args - - def _format_parameter_set(self, parameters): - if not isinstance(parameters, Mapping): - raise NotSupportedError("Expected a mapping of parameters. 
Array parameters are not supported.") - return [self.prepare_param(k, v) for k, v in parameters.items()] - - def _get_database_error(self, original_error): - error_msg = getattr(original_error, "response", {}).get("Error", {}).get("Message", "") - try: - res = re.search(r"Error code: (\d+); SQLState: (\d+)$", error_msg) - if res: # MySQL error - error_code = int(res.group(1)) - error_class = MySQLError.from_code(error_code) - error = error_class(error_msg) - error.response = getattr(original_error, "response", {}) - return error - res = re.search(r"ERROR: .*(?:\n |;) Position: (\d+); SQLState: (\w+)$", error_msg) - if res: # PostgreSQL error - error_code = res.group(2) - error_class = PostgreSQLError.from_code(error_code) - error = error_class(error_msg) - error.response = getattr(original_error, "response", {}) - return error - except Exception: - pass - return DatabaseError(original_error) - async def execute(self, operation, parameters=None): self._current_response, self._iterator, self._paging_state = None, None, None execute_statement_args = dict(self._prepare_execute_args(operation), includeResultMetadata=True) @@ -300,25 +158,6 @@ async def execute(self, operation, parameters=None): raise self._get_database_error(e) from e self._iterator = self.__aiter__() - @property - def rowcount(self): - if self._current_response: - if "records" in self._current_response: - return len(self._current_response["records"]) - elif "numberOfRecordsUpdated" in self._current_response: - return self._current_response["numberOfRecordsUpdated"] - return -1 - - @property - def lastrowid(self): - # TODO: this may not make sense if the previous statement is not an INSERT - if self._current_response and self._current_response.get("generatedFields"): - return self._render_value(self._current_response["generatedFields"][-1]) - - def _page_input(self, iterable, page_size=1000): - iterable = iter(iterable) - return iter(lambda: list(itertools.islice(iterable, page_size)), []) - async def executemany(self, operation, seq_of_parameters): logger.debug("executemany %s", reprlib.repr(operation.strip())) for batch in self._page_input(seq_of_parameters): @@ -330,42 +169,6 @@ async def executemany(self, operation, seq_of_parameters): except self._client.exceptions.BadRequestException as e: raise self._get_database_error(e) from e - def _render_response(self, response): - if "records" in response: - for i, record in enumerate(response["records"]): - response["records"][i] = tuple( - self._render_value(value, col_desc=self.description[j] if self.description else None) - for j, value in enumerate(record) - ) - return response - - def _render_value(self, value, col_desc=None): - if value.get("isNull"): - return None - elif "arrayValue" in value: - if "arrayValues" in value["arrayValue"]: - return [self._render_value(nested) for nested in value["arrayValue"]["arrayValues"]] - else: - return list(value["arrayValue"].values())[0] - else: - scalar_value = list(value.values())[0] - if col_desc and col_desc.type_code in self._data_api_type_hint_map: - if col_desc.type_code == Decimal: - scalar_value = Decimal(scalar_value) - else: - try: - scalar_value = col_desc.type_code.fromisoformat(scalar_value) - except (AttributeError, ValueError): # fromisoformat not supported on Python < 3.7 - if col_desc.type_code == datetime.date: - scalar_value = datetime.datetime.strptime(scalar_value, "%Y-%m-%d").date() - elif col_desc.type_code == datetime.time: - scalar_value = datetime.datetime.strptime(scalar_value, "%H:%M:%S").time() - elif "." 
in scalar_value:
-                            scalar_value = datetime.datetime.strptime(scalar_value, "%Y-%m-%d %H:%M:%S.%f")
-                        else:
-                            scalar_value = datetime.datetime.strptime(scalar_value, "%Y-%m-%d %H:%M:%S")
-            return scalar_value
-
     async def scroll(self, value, mode="relative"):
         if not self._paging_state:
             raise InterfaceError("Cursor scroll attempted but pagination is not active")
@@ -448,12 +251,6 @@ async def fetchall(self):
             results.append(record)
         return results
 
-    def setinputsizes(self, sizes):
-        pass
-
-    def setoutputsize(self, size, column=None):
-        pass
-
     async def close(self):
         pass
 
diff --git a/aurora_data_api/base.py b/aurora_data_api/base.py
new file mode 100644
index 0000000..9d02649
--- /dev/null
+++ b/aurora_data_api/base.py
@@ -0,0 +1,266 @@
+"""
+Base classes for Aurora Data API clients and cursors
+"""
+import os, datetime, ipaddress, uuid, time, random, string, logging, itertools, reprlib, json, re
+from decimal import Decimal
+from collections import namedtuple
+from collections.abc import Mapping
+from .exceptions import (
+    Warning,
+    Error,
+    InterfaceError,
+    DatabaseError,
+    DataError,
+    OperationalError,
+    IntegrityError,
+    InternalError,
+    ProgrammingError,
+    NotSupportedError,
+    MySQLError,
+    PostgreSQLError,
+)
+from .error_codes_mysql import MySQLErrorCodes
+from .error_codes_postgresql import PostgreSQLErrorCodes
+
+apilevel = "2.0"
+
+threadsafety = 0
+
+paramstyle = "named"
+
+Date = datetime.date
+Time = datetime.time
+Timestamp = datetime.datetime
+DateFromTicks = datetime.date.fromtimestamp
+# TimeFromTicks = datetime.time.fromtimestamp TODO
+TimestampFromTicks = datetime.datetime.fromtimestamp
+Binary = bytes
+STRING = str
+BINARY = bytes
+NUMBER = float
+DATETIME = datetime.datetime
+ROWID = str
+DECIMAL = Decimal
+
+ColumnDescription = namedtuple("ColumnDescription", "name type_code display_size internal_size precision scale null_ok")
+ColumnDescription.__new__.__defaults__ = (None,) * len(ColumnDescription._fields)
+
+logger = logging.getLogger(__name__)
+
+
+class BaseAuroraDataAPIClient:
+    """Base class for Aurora Data API clients"""
+
+    def __init__(
+        self,
+        dbname=None,
+        aurora_cluster_arn=None,
+        secret_arn=None,
+        rds_data_client=None,
+        charset=None,
+        continue_after_timeout=None,
+    ):
+        self._client = rds_data_client
+        self._dbname = dbname
+        self._aurora_cluster_arn = aurora_cluster_arn or os.environ.get("AURORA_CLUSTER_ARN")
+        self._secret_arn = secret_arn or os.environ.get("AURORA_SECRET_ARN")
+        self._charset = charset
+        self._transaction_id = None
+        self._continue_after_timeout = continue_after_timeout
+
+    def close(self):
+        pass
+
+
+class BaseAuroraDataAPICursor:
+    """Base class for Aurora Data API cursors"""
+
+    _pg_type_map = {
+        "int": int,
+        "int2": int,
+        "int4": int,
+        "int8": int,
+        "float4": float,
+        "float8": float,
+        "serial2": int,
+        "serial4": int,
+        "serial8": int,
+        "bool": bool,
+        "varbit": bytes,
+        "bytea": bytearray,
+        "char": str,
+        "varchar": str,
+        "cidr": ipaddress.ip_network,
+        "date": datetime.date,
+        "inet": ipaddress.ip_address,
+        "json": dict,
+        "jsonb": dict,
+        "money": str,
+        "text": str,
+        "time": datetime.time,
+        "timestamp": datetime.datetime,
+        "uuid": uuid.uuid4,
+        "numeric": Decimal,
+        "decimal": Decimal,
+    }
+    _data_api_type_map = {
+        bytes: "blobValue",
+        bool: "booleanValue",
+        float: "doubleValue",
+        int: "longValue",
+        str: "stringValue",
+        Decimal: "stringValue",
+        # list: "arrayValue"
+    }
+    _data_api_type_hint_map = {
+        datetime.date: "DATE",
+        datetime.time: "TIME",
+        datetime.datetime: "TIMESTAMP",
+        Decimal: "DECIMAL",
+    }
+
+    def __init__(
+        self,
+        client=None,
+        dbname=None,
+        aurora_cluster_arn=None,
+        secret_arn=None,
+        transaction_id=None,
+        continue_after_timeout=None,
+    ):
+        self.arraysize = 1000
+        self.description = None
+        self._client = client
+        self._dbname = dbname
+        self._aurora_cluster_arn = aurora_cluster_arn
+        self._secret_arn = secret_arn
+        self._transaction_id = transaction_id
+        self._current_response = None
+        self._iterator = None
+        self._paging_state = None
+        self._continue_after_timeout = continue_after_timeout
+
+    def prepare_param(self, param_name, param_value):
+        if param_value is None:
+            return dict(name=param_name, value=dict(isNull=True))
+        param_data_api_type = self._data_api_type_map.get(type(param_value), "stringValue")
+        param = dict(name=param_name, value={param_data_api_type: param_value})
+        if param_data_api_type == "stringValue" and not isinstance(param_value, str):
+            param["value"][param_data_api_type] = str(param_value)
+        if type(param_value) in self._data_api_type_hint_map:
+            param["typeHint"] = self._data_api_type_hint_map[type(param_value)]
+        return param
+
+        # if param_data_api_type == "arrayValue" and len(param_value) > 0:
+        #     return {
+        #         param_data_api_type: {
+        #             self._data_api_type_map.get(type(param_value[0]), "stringValue") + "s": param_value
+        #         }
+        #     }
+
+    def _set_description(self, column_metadata):
+        # see https://www.postgresql.org/docs/9.5/datatype.html
+        self.description = []
+        for column in column_metadata:
+            col_desc = ColumnDescription(
+                name=column["name"], type_code=self._pg_type_map.get(column["typeName"].lower(), str)
+            )
+            self.description.append(col_desc)
+
+    def _prepare_execute_args(self, operation):
+        execute_args = dict(
+            database=self._dbname, resourceArn=self._aurora_cluster_arn, secretArn=self._secret_arn, sql=operation
+        )
+        if self._transaction_id:
+            execute_args["transactionId"] = self._transaction_id
+        return execute_args
+
+    def _format_parameter_set(self, parameters):
+        if not isinstance(parameters, Mapping):
+            raise NotSupportedError("Expected a mapping of parameters. Array parameters are not supported.")
+        return [self.prepare_param(k, v) for k, v in parameters.items()]
+
+    def _get_database_error(self, original_error):
+        error_msg = getattr(original_error, "response", {}).get("Error", {}).get("Message", "")
+        try:
+            res = re.search(r"Error code: (\d+); SQLState: (\d+)$", error_msg)
+            if res:  # MySQL error
+                error_code = int(res.group(1))
+                error_class = MySQLError.from_code(error_code)
+                error = error_class(error_msg)
+                error.response = getattr(original_error, "response", {})
+                return error
+            res = re.search(r"ERROR: .*(?:\n |;) Position: (\d+); SQLState: (\w+)$", error_msg)
+            if res:  # PostgreSQL error
+                error_code = res.group(2)
+                error_class = PostgreSQLError.from_code(error_code)
+                error = error_class(error_msg)
+                error.response = getattr(original_error, "response", {})
+                return error
+        except Exception:
+            pass
+        return DatabaseError(original_error)
+
+    @property
+    def rowcount(self):
+        if self._current_response:
+            if "records" in self._current_response:
+                return len(self._current_response["records"])
+            elif "numberOfRecordsUpdated" in self._current_response:
+                return self._current_response["numberOfRecordsUpdated"]
+        return -1
+
+    @property
+    def lastrowid(self):
+        # TODO: this may not make sense if the previous statement is not an INSERT
+        if self._current_response and self._current_response.get("generatedFields"):
+            return self._render_value(self._current_response["generatedFields"][-1])
+
+    def _page_input(self, iterable, page_size=1000):
+        iterable = iter(iterable)
+        return iter(lambda: list(itertools.islice(iterable, page_size)), [])
+
+    def _render_response(self, response):
+        if "records" in response:
+            for i, record in enumerate(response["records"]):
+                response["records"][i] = tuple(
+                    self._render_value(value, col_desc=self.description[j] if self.description else None)
+                    for j, value in enumerate(record)
+                )
+        return response
+
+    def _render_value(self, value, col_desc=None):
+        if value.get("isNull"):
+            return None
+        elif "arrayValue" in value:
+            if "arrayValues" in value["arrayValue"]:
+                return [self._render_value(nested) for nested in value["arrayValue"]["arrayValues"]]
+            else:
+                return list(value["arrayValue"].values())[0]
+        else:
+            scalar_value = list(value.values())[0]
+            if col_desc and col_desc.type_code in self._data_api_type_hint_map:
+                if col_desc.type_code == Decimal:
+                    scalar_value = Decimal(scalar_value)
+                else:
+                    try:
+                        scalar_value = col_desc.type_code.fromisoformat(scalar_value)
+                    except (AttributeError, ValueError):  # fromisoformat not supported on Python < 3.7
+                        if col_desc.type_code == datetime.date:
+                            scalar_value = datetime.datetime.strptime(scalar_value, "%Y-%m-%d").date()
+                        elif col_desc.type_code == datetime.time:
+                            scalar_value = datetime.datetime.strptime(scalar_value, "%H:%M:%S").time()
+                        elif "." in scalar_value:
+                            scalar_value = datetime.datetime.strptime(scalar_value, "%Y-%m-%d %H:%M:%S.%f")
+                        else:
+                            scalar_value = datetime.datetime.strptime(scalar_value, "%Y-%m-%d %H:%M:%S")
+            return scalar_value
+
+    def setinputsizes(self, sizes):
+        pass
+
+    def setoutputsize(self, size, column=None):
+        pass
+
+    def close(self):
+        pass
diff --git a/aurora_data_api/sync.py b/aurora_data_api/sync.py
index 21e2128..709f53c 100644
--- a/aurora_data_api/sync.py
+++ b/aurora_data_api/sync.py
@@ -5,6 +5,12 @@
 from decimal import Decimal
 from collections import namedtuple
 from collections.abc import Mapping
+from .base import BaseAuroraDataAPIClient, BaseAuroraDataAPICursor, logger
+from .base import (
+    apilevel, threadsafety, paramstyle, Date, Time, Timestamp, DateFromTicks,
+    TimestampFromTicks, Binary, STRING, BINARY, NUMBER, DATETIME, ROWID, DECIMAL,
+    ColumnDescription
+)
 from .exceptions import (
     Warning,
     Error,
@@ -19,37 +25,10 @@
     MySQLError,
     PostgreSQLError,
 )
-from .error_codes_mysql import MySQLErrorCodes
-from .error_codes_postgresql import PostgreSQLErrorCodes
 import botocore.session
 
-apilevel = "2.0"
-
-threadsafety = 0
-
-paramstyle = "named"
-
-Date = datetime.date
-Time = datetime.time
-Timestamp = datetime.datetime
-DateFromTicks = datetime.date.fromtimestamp
-# TimeFromTicks = datetime.time.fromtimestamp TODO
-TimestampFromTicks = datetime.datetime.fromtimestamp
-Binary = bytes
-STRING = str
-BINARY = bytes
-NUMBER = float
-DATETIME = datetime.datetime
-ROWID = str
-DECIMAL = Decimal
-
-ColumnDescription = namedtuple("ColumnDescription", "name type_code display_size internal_size precision scale null_ok")
-ColumnDescription.__new__.__defaults__ = (None,) * len(ColumnDescription._fields)
-logger = logging.getLogger(__name__)
-
-
-class SyncAuroraDataAPIClient:
+class SyncAuroraDataAPIClient(BaseAuroraDataAPIClient):
     _client_init_lock = threading.Lock()
 
     def __init__(
@@ -61,20 +40,11 @@ def __init__(
         charset=None,
         continue_after_timeout=None,
     ):
-        self._client = rds_data_client
+        super().__init__(dbname, aurora_cluster_arn, secret_arn, rds_data_client, charset, continue_after_timeout)
         if rds_data_client is None:
             with self._client_init_lock:
                 session = botocore.session.get_session()
                 self._client = session.create_client("rds-data")
-        self._dbname = dbname
-        self._aurora_cluster_arn = aurora_cluster_arn or os.environ.get("AURORA_CLUSTER_ARN")
-        self._secret_arn = secret_arn or os.environ.get("AURORA_SECRET_ARN")
-        self._charset = charset
-        self._transaction_id = None
-        self._continue_after_timeout = continue_after_timeout
-
-    def close(self):
-        pass
 
     def commit(self):
         if self._transaction_id:
@@ -123,51 +93,7 @@ def __exit__(self, err_type, value, traceback):
         self.commit()
 
 
-class SyncAuroraDataAPICursor:
-    _pg_type_map = {
-        "int": int,
-        "int2": int,
-        "int4": int,
-        "int8": int,
-        "float4": float,
-        "float8": float,
-        "serial2": int,
-        "serial4": int,
-        "serial8": int,
-        "bool": bool,
-        "varbit": bytes,
-        "bytea": bytearray,
-        "char": str,
-        "varchar": str,
-        "cidr": ipaddress.ip_network,
-        "date": datetime.date,
-        "inet": ipaddress.ip_address,
-        "json": dict,
-        "jsonb": dict,
-        "money": str,
-        "text": str,
-        "time": datetime.time,
-        "timestamp": datetime.datetime,
-        "uuid": uuid.uuid4,
-        "numeric": Decimal,
-        "decimal": Decimal,
-    }
-    _data_api_type_map = {
-        bytes: "blobValue",
-        bool: "booleanValue",
-        float: "doubleValue",
-        int: "longValue",
-        str: "stringValue",
-        Decimal: "stringValue",
-        # list: "arrayValue"
-    }
-    _data_api_type_hint_map = {
-        datetime.date: "DATE",
-        datetime.time: "TIME",
-        datetime.datetime: "TIMESTAMP",
-        Decimal: "DECIMAL",
-    }
-
+class SyncAuroraDataAPICursor(BaseAuroraDataAPICursor):
     def __init__(
         self,
         client=None,
@@ -177,44 +103,7 @@ def __init__(
         transaction_id=None,
         continue_after_timeout=None,
     ):
-        self.arraysize = 1000
-        self.description = None
-        self._client = client
-        self._dbname = dbname
-        self._aurora_cluster_arn = aurora_cluster_arn
-        self._secret_arn = secret_arn
-        self._transaction_id = transaction_id
-        self._current_response = None
-        self._iterator = None
-        self._paging_state = None
-        self._continue_after_timeout = continue_after_timeout
-
-    def prepare_param(self, param_name, param_value):
-        if param_value is None:
-            return dict(name=param_name, value=dict(isNull=True))
-        param_data_api_type = self._data_api_type_map.get(type(param_value), "stringValue")
-        param = dict(name=param_name, value={param_data_api_type: param_value})
-        if param_data_api_type == "stringValue" and not isinstance(param_value, str):
-            param["value"][param_data_api_type] = str(param_value)
-        if type(param_value) in self._data_api_type_hint_map:
-            param["typeHint"] = self._data_api_type_hint_map[type(param_value)]
-        return param
-
-        # if param_data_api_type == "arrayValue" and len(param_value) > 0:
-        #     return {
-        #         param_data_api_type: {
-        #             self._data_api_type_map.get(type(param_value[0]), "stringValue") + "s": param_value
-        #         }
-        #     }
-
-    def _set_description(self, column_metadata):
-        # see https://www.postgresql.org/docs/9.5/datatype.html
-        self.description = []
-        for column in column_metadata:
-            col_desc = ColumnDescription(
-                name=column["name"], type_code=self._pg_type_map.get(column["typeName"].lower(), str)
-            )
-            self.description.append(col_desc)
+        super().__init__(client, dbname, aurora_cluster_arn, secret_arn, transaction_id, continue_after_timeout)
 
     def _start_paginated_query(self, execute_statement_args, records_per_page=None):
         # MySQL cursors are non-scrollable (https://dev.mysql.com/doc/refman/8.0/en/cursors.html)
@@ -232,40 +121,6 @@ def _start_paginated_query(self, execute_statement_args, records_per_page=None):
             "pg_cursor_name": pg_cursor_name,
         }
 
-    def _prepare_execute_args(self, operation):
-        execute_args = dict(
-            database=self._dbname, resourceArn=self._aurora_cluster_arn, secretArn=self._secret_arn, sql=operation
-        )
-        if self._transaction_id:
-            execute_args["transactionId"] = self._transaction_id
-        return execute_args
-
-    def _format_parameter_set(self, parameters):
-        if not isinstance(parameters, Mapping):
-            raise NotSupportedError("Expected a mapping of parameters. Array parameters are not supported.")
-        return [self.prepare_param(k, v) for k, v in parameters.items()]
-
-    def _get_database_error(self, original_error):
-        error_msg = getattr(original_error, "response", {}).get("Error", {}).get("Message", "")
-        try:
-            res = re.search(r"Error code: (\d+); SQLState: (\d+)$", error_msg)
-            if res:  # MySQL error
-                error_code = int(res.group(1))
-                error_class = MySQLError.from_code(error_code)
-                error = error_class(error_msg)
-                error.response = getattr(original_error, "response", {})
-                return error
-            res = re.search(r"ERROR: .*(?:\n |;) Position: (\d+); SQLState: (\w+)$", error_msg)
-            if res:  # PostgreSQL error
-                error_code = res.group(2)
-                error_class = PostgreSQLError.from_code(error_code)
-                error = error_class(error_msg)
-                error.response = getattr(original_error, "response", {})
-                return error
-        except Exception:
-            pass
-        return DatabaseError(original_error)
-
     def execute(self, operation, parameters=None):
         self._current_response, self._iterator, self._paging_state = None, None, None
         execute_statement_args = dict(self._prepare_execute_args(operation), includeResultMetadata=True)
@@ -288,25 +143,6 @@ def execute(self, operation, parameters=None):
                 raise self._get_database_error(e) from e
         self._iterator = iter(self)
 
-    @property
-    def rowcount(self):
-        if self._current_response:
-            if "records" in self._current_response:
-                return len(self._current_response["records"])
-            elif "numberOfRecordsUpdated" in self._current_response:
-                return self._current_response["numberOfRecordsUpdated"]
-        return -1
-
-    @property
-    def lastrowid(self):
-        # TODO: this may not make sense if the previous statement is not an INSERT
-        if self._current_response and self._current_response.get("generatedFields"):
-            return self._render_value(self._current_response["generatedFields"][-1])
-
-    def _page_input(self, iterable, page_size=1000):
-        iterable = iter(iterable)
-        return iter(lambda: list(itertools.islice(iterable, page_size)), [])
-
     def executemany(self, operation, seq_of_parameters):
         logger.debug("executemany %s", reprlib.repr(operation.strip()))
         for batch in self._page_input(seq_of_parameters):
@@ -318,42 +154,6 @@ def executemany(self, operation, seq_of_parameters):
         except self._client.exceptions.BadRequestException as e:
             raise self._get_database_error(e) from e
 
-    def _render_response(self, response):
-        if "records" in response:
-            for i, record in enumerate(response["records"]):
-                response["records"][i] = tuple(
-                    self._render_value(value, col_desc=self.description[j] if self.description else None)
-                    for j, value in enumerate(record)
-                )
-        return response
-
-    def _render_value(self, value, col_desc=None):
-        if value.get("isNull"):
-            return None
-        elif "arrayValue" in value:
-            if "arrayValues" in value["arrayValue"]:
-                return [self._render_value(nested) for nested in value["arrayValue"]["arrayValues"]]
-            else:
-                return list(value["arrayValue"].values())[0]
-        else:
-            scalar_value = list(value.values())[0]
-            if col_desc and col_desc.type_code in self._data_api_type_hint_map:
-                if col_desc.type_code == Decimal:
-                    scalar_value = Decimal(scalar_value)
-                else:
-                    try:
-                        scalar_value = col_desc.type_code.fromisoformat(scalar_value)
-                    except (AttributeError, ValueError):  # fromisoformat not supported on Python < 3.7
-                        if col_desc.type_code == datetime.date:
-                            scalar_value = datetime.datetime.strptime(scalar_value, "%Y-%m-%d").date()
-                        elif col_desc.type_code == datetime.time:
-                            scalar_value = datetime.datetime.strptime(scalar_value, "%H:%M:%S").time()
-                        elif "." in scalar_value:
-                            scalar_value = datetime.datetime.strptime(scalar_value, "%Y-%m-%d %H:%M:%S.%f")
-                        else:
-                            scalar_value = datetime.datetime.strptime(scalar_value, "%Y-%m-%d %H:%M:%S")
-            return scalar_value
-
     def scroll(self, value, mode="relative"):
         if not self._paging_state:
             raise InterfaceError("Cursor scroll attempted but pagination is not active")
@@ -416,15 +216,6 @@ def fetchmany(self, size=None):
     def fetchall(self):
         return list(self._iterator)
 
-    def setinputsizes(self, sizes):
-        pass
-
-    def setoutputsize(self, size, column=None):
-        pass
-
-    def close(self):
-        pass
-
     def __enter__(self):
         return self
 

From 5384b70436bfb8e1faaa42144e503142cafe7d5a Mon Sep 17 00:00:00 2001
From: tsuga <2888173+tsuga@users.noreply.github.com>
Date: Sun, 27 Jul 2025 12:48:32 +0000
Subject: [PATCH 8/9] Update readme and setup.py

---
 README.rst | 26 ++++++++++++++++++++++++++
 setup.py   |  2 +-
 2 files changed, 27 insertions(+), 1 deletion(-)

diff --git a/README.rst b/README.rst
index 91f4bd0..8976c84 100644
--- a/README.rst
+++ b/README.rst
@@ -7,6 +7,10 @@ Installation
 
     pip install aurora-data-api
 
+For async support using aiobotocore::
+
+    pip install aurora-data-api[async]
+
 Prerequisites
 -------------
 * Set up an AWS
@@ -66,6 +70,19 @@ the standard main entry point, and accepts two implementation-specific keyword a
         cursor.execute("select * from pg_catalog.pg_tables")
         print(cursor.fetchall())
 
+For async usage (requires ``aurora-data-api[async]``):
+
+.. code-block:: python
+
+    import aurora_data_api.async_ as aurora_data_api_async
+
+    cluster_arn = "arn:aws:rds:us-east-1:123456789012:cluster:my-aurora-serverless-cluster"
+    secret_arn = "arn:aws:secretsmanager:us-east-1:123456789012:secret:rds-db-credentials/MY_DB"
+    async with await aurora_data_api_async.connect(aurora_cluster_arn=cluster_arn, secret_arn=secret_arn, database="my_db") as conn:
+        async with await conn.cursor() as cursor:
+            await cursor.execute("select * from pg_catalog.pg_tables")
+            print(await cursor.fetchall())
+
 The cursor supports iteration (and automatically wraps the query in a server-side cursor and paginates it if required):
 
 .. code-block:: python
@@ -74,6 +91,15 @@ The cursor supports iteration (and automatically wraps the query in a server-sid
     for row in cursor.execute("select * from pg_catalog.pg_tables"):
         print(row)
 
+For async iteration:
+
+.. code-block:: python
+
+    async with await conn.cursor() as cursor:
+        await cursor.execute("select * from pg_catalog.pg_tables")
+        async for row in cursor:
+            print(row)
+
 Motivation
 ----------
 The `RDS Data API `_ is the link between the
diff --git a/setup.py b/setup.py
index 1981830..9636be9 100755
--- a/setup.py
+++ b/setup.py
@@ -4,7 +4,7 @@
 
 setup(
     name="aurora-data-api",
-    version="0.5.0",
+    version="0.6.0",
     url="https://github.com/chanzuckerberg/aurora-data-api",
     license="Apache Software License",
     author="Andrey Kislyuk",

From 29d9ffc6999054b2958c75cf3afa7d8b2ad60807 Mon Sep 17 00:00:00 2001
From: tsuga <2888173+tsuga@users.noreply.github.com>
Date: Fri, 1 Aug 2025 03:32:31 +0000
Subject: [PATCH 9/9] Revert version increment

---
 setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index 9636be9..1981830 100755
--- a/setup.py
+++ b/setup.py
@@ -4,7 +4,7 @@
 
 setup(
     name="aurora-data-api",
-    version="0.6.0",
+    version="0.5.0",
     url="https://github.com/chanzuckerberg/aurora-data-api",
     license="Apache Software License",
     author="Andrey Kislyuk",