From a41d45dad6eacd1698b32f5bf1bbc928632d7835 Mon Sep 17 00:00:00 2001 From: Daniele Briggi Date: Tue, 6 Aug 2024 12:03:14 +0000 Subject: [PATCH 1/2] fix(compression): read one command at a time from socket - fix chunked rowset and decompression, chunk are compressed one by one - chunks are read from socket to complete the message - json is returned as JSON Object --- src/sqlitecloud/driver.py | 218 +++++++++++++++------------ src/tests/integration/test_client.py | 77 +++++++++- src/tests/integration/test_pubsub.py | 34 +++-- 3 files changed, 217 insertions(+), 112 deletions(-) diff --git a/src/sqlitecloud/driver.py b/src/sqlitecloud/driver.py index f951d68..08e12be 100644 --- a/src/sqlitecloud/driver.py +++ b/src/sqlitecloud/driver.py @@ -262,15 +262,15 @@ def _internal_pubsub_thread(self, connection: SQLiteCloudConnect) -> None: blen -= nread buffer += data - SQLiteCloud_number = self._internal_parse_number(buffer) - clen = SQLiteCloud_number.value + sqlitecloud_number = self._internal_parse_number(buffer) + clen = sqlitecloud_number.value if clen == 0: continue # check if read is complete # clen is the lenght parsed in the buffer # cstart is the index of the first space - cstart = SQLiteCloud_number.cstart + cstart = sqlitecloud_number.cstart if clen + cstart != tread: continue @@ -532,61 +532,80 @@ def _internal_socket_read( slicing the buffer into parts if there are special characters like "ò". """ buffer = b"" - buffer_size = 8192 + command_type = "" + command_length_value = b"" nread = 0 sock = connection.socket if main_socket else connection.pubsub_socket + # read the lenght of the command, eg: + # ?LEN , where `?` is any command type + # _ for null command + # :145 for integer command with value 145 while True: try: - data = sock.recv(buffer_size) + data = sock.recv(1) if not data: - raise SQLiteCloudException("Incomplete response from server.") + raise SQLiteCloudException( + "Incomplete response from server. Cannot read the command length." + ) except Exception as exc: raise SQLiteCloudException( - "An error occurred while reading data from the socket.", + "An error occurred while reading command length from the socket.", SQLITECLOUD_INTERNAL_ERRCODE.NETWORK, ) from exc - # the expected data length to read - # matches the string size before decoding it nread += len(data) - # update buffers buffer += data - c = chr(buffer[0]) + # first character is the type of the message + if nread == 1: + command_type = data.decode() + continue - if ( - c == SQLITECLOUD_CMD.INT.value - or c == SQLITECLOUD_CMD.FLOAT.value - or c == SQLITECLOUD_CMD.NULL.value - ): - if not buffer.endswith(b" "): - continue - elif c == SQLITECLOUD_CMD.ROWSET_CHUNK.value: - isEndOfChunk = buffer.endswith(SQLITECLOUD_ROWSET.CHUNKS_END.value) - if not isEndOfChunk: - continue - else: - SQLiteCloud_number = self._internal_parse_number(buffer) - n = SQLiteCloud_number.value - cstart = SQLiteCloud_number.cstart + # end of len value + if data == b" ": + break - can_be_zerolength = ( - c == SQLITECLOUD_CMD.BLOB.value or c == SQLITECLOUD_CMD.STRING.value - ) - if n == 0 and not can_be_zerolength: - continue - if n + cstart != nread: - continue + command_length_value += data + if ( + command_type == SQLITECLOUD_CMD.INT.value + or command_type == SQLITECLOUD_CMD.FLOAT.value + or command_type == SQLITECLOUD_CMD.NULL.value + ): return self._internal_parse_buffer(connection, buffer, len(buffer)) + command_length = int(command_length_value) + + # read the command + nread = 0 + + while nread < command_length: + buffer_size = min(command_length - nread, 8192) + + try: + data = sock.recv(buffer_size) + if not data: + raise SQLiteCloudException( + "Incomplete response from server. Cannot read the command." + ) + except Exception as exc: + raise SQLiteCloudException( + "An error occurred while reading the command from the socket.", + SQLITECLOUD_INTERNAL_ERRCODE.NETWORK, + ) from exc + + nread += len(data) + buffer += data + + return self._internal_parse_buffer(connection, buffer, len(buffer)) + def _internal_parse_number( self, buffer: bytes, index: int = 1 ) -> SQLiteCloudNumber: - SQLiteCloud_number = SQLiteCloudNumber() - SQLiteCloud_number.value = 0 + sqlitecloud_number = SQLiteCloudNumber() + sqlitecloud_number.value = 0 extvalue = 0 isext = False blen = len(buffer) @@ -602,9 +621,9 @@ def _internal_parse_number( # check for end of value if c == " ": - SQLiteCloud_number.cstart = i + 1 - SQLiteCloud_number.extcode = extvalue - return SQLiteCloud_number + sqlitecloud_number.cstart = i + 1 + sqlitecloud_number.extcode = extvalue + return sqlitecloud_number val = int(c) if c.isdigit() else 0 @@ -612,10 +631,10 @@ def _internal_parse_number( if isext: extvalue = (extvalue * 10) + val else: - SQLiteCloud_number.value = (SQLiteCloud_number.value * 10) + val + sqlitecloud_number.value = (sqlitecloud_number.value * 10) + val - SQLiteCloud_number.value = 0 - return SQLiteCloud_number + sqlitecloud_number.value = 0 + return sqlitecloud_number def _internal_parse_buffer( self, connection: SQLiteCloudConnect, buffer: bytes, blen: int @@ -665,11 +684,7 @@ def _internal_parse_buffer( if len_ == 0: return SQLiteCloudResult(SQLITECLOUD_RESULT_TYPE.RESULT_STRING, "") - tag = ( - SQLITECLOUD_RESULT_TYPE.RESULT_JSON - if cmd == SQLITECLOUD_CMD.JSON.value - else SQLITECLOUD_RESULT_TYPE.RESULT_STRING - ) + tag = SQLITECLOUD_RESULT_TYPE.RESULT_STRING if cmd == SQLITECLOUD_CMD.ZEROSTRING.value: len_ -= 1 @@ -693,6 +708,10 @@ def _internal_parse_buffer( ) elif cmd == SQLITECLOUD_CMD.BLOB.value: tag = SQLITECLOUD_RESULT_TYPE.RESULT_BLOB + elif cmd == SQLITECLOUD_CMD.JSON.value: + return SQLiteCloudResult( + SQLITECLOUD_RESULT_TYPE.RESULT_JSON, json.loads(clone) + ) clone = clone.decode() if cmd != SQLITECLOUD_CMD.BLOB.value else clone return SQLiteCloudResult(tag, clone) @@ -740,11 +759,10 @@ def _internal_parse_buffer( rowset_signature.ncols, ) - # continue parsing next chunk in the buffer - sign_len = rowset_signature.len - buffer = buffer[sign_len + len(f"/{sign_len} ") :] - if cmd == SQLITECLOUD_CMD.ROWSET_CHUNK.value and buffer: - return self._internal_parse_buffer(connection, buffer, len(buffer)) + # continue reading from the socket + # until the end-of-chunk condition + if cmd == SQLITECLOUD_CMD.ROWSET_CHUNK.value: + return self._internal_socket_read(connection) return rowset @@ -752,8 +770,8 @@ def _internal_parse_buffer( return SQLiteCloudResult(SQLITECLOUD_RESULT_TYPE.RESULT_NONE, None) elif cmd in [SQLITECLOUD_CMD.INT.value, SQLITECLOUD_CMD.FLOAT.value]: - SQLiteCloud_value = self._internal_parse_value(buffer) - clone = SQLiteCloud_value.value + sqlitecloud_value = self._internal_parse_value(buffer) + clone = sqlitecloud_value.value tag = ( SQLITECLOUD_RESULT_TYPE.RESULT_INTEGER @@ -784,6 +802,10 @@ def _internal_uncompress_data(self, buffer: bytes) -> Optional[bytes]: Returns: str: The uncompressed data. """ + # buffer may contain a sequence of compressed data + # eg, a compressed rowset split in chunks is a sequence of rowset chunks + # compressed individually, each one with its compressed header, + # rowset header and compressed data space_index = buffer.index(b" ") buffer = buffer[space_index + 1 :] @@ -821,14 +843,14 @@ def _internal_parse_array(self, buffer: bytes) -> list: r: str = [] for i in range(n): - SQLiteCloud_value = self._internal_parse_value(buffer, start) - start += SQLiteCloud_value.cellsize - r.append(SQLiteCloud_value.value) + sqlitecloud_value = self._internal_parse_value(buffer, start) + start += sqlitecloud_value.cellsize + r.append(sqlitecloud_value.value) return r def _internal_parse_value(self, buffer: bytes, index: int = 0) -> SQLiteCloudValue: - SQLiteCloud_value = SQLiteCloudValue() + sqlitecloud_value = SQLiteCloudValue() len = 0 cellsize = 0 @@ -839,14 +861,14 @@ def _internal_parse_value(self, buffer: bytes, index: int = 0) -> SQLiteCloudVal if cellsize is not None: cellsize = 2 - SQLiteCloud_value.len = len - SQLiteCloud_value.cellsize = cellsize + sqlitecloud_value.len = len + sqlitecloud_value.cellsize = cellsize - return SQLiteCloud_value + return sqlitecloud_value - SQLiteCloud_number = self._internal_parse_number(buffer, index + 1) - blen = SQLiteCloud_number.value - cstart = SQLiteCloud_number.cstart + sqlitecloud_number = self._internal_parse_number(buffer, index + 1) + blen = sqlitecloud_number.value + cstart = sqlitecloud_number.cstart # handle decimal/float cases if c == SQLITECLOUD_CMD.INT.value or c == SQLITECLOUD_CMD.FLOAT.value: @@ -854,20 +876,20 @@ def _internal_parse_value(self, buffer: bytes, index: int = 0) -> SQLiteCloudVal len = nlen - 2 cellsize = nlen - SQLiteCloud_value.value = (buffer[index + 1 : index + 1 + len]).decode() - SQLiteCloud_value.len - SQLiteCloud_value.cellsize = cellsize + sqlitecloud_value.value = (buffer[index + 1 : index + 1 + len]).decode() + sqlitecloud_value.len + sqlitecloud_value.cellsize = cellsize - return SQLiteCloud_value + return sqlitecloud_value len = blen - 1 if c == SQLITECLOUD_CMD.ZEROSTRING.value else blen cellsize = blen + cstart - index - SQLiteCloud_value.value = (buffer[cstart : cstart + len]).decode() - SQLiteCloud_value.len = len - SQLiteCloud_value.cellsize = cellsize + sqlitecloud_value.value = (buffer[cstart : cstart + len]).decode() + sqlitecloud_value.len = len + sqlitecloud_value.cellsize = cellsize - return SQLiteCloud_value + return sqlitecloud_value def _internal_parse_rowset_signature( self, buffer: bytes @@ -951,9 +973,9 @@ def _internal_parse_rowset_header( # parse column names rowset.colname = [] for i in range(ncols): - SQLiteCloud_number = self._internal_parse_number(buffer, start) - number_len = SQLiteCloud_number.value - cstart = SQLiteCloud_number.cstart + sqlitecloud_number = self._internal_parse_number(buffer, start) + number_len = sqlitecloud_number.value + cstart = sqlitecloud_number.cstart value = buffer[cstart : cstart + number_len] rowset.colname.append(value.decode()) start = cstart + number_len @@ -969,9 +991,9 @@ def _internal_parse_rowset_header( # parse declared types rowset.decltype = [] for i in range(ncols): - SQLiteCloud_number = self._internal_parse_number(buffer, start) - number_len = SQLiteCloud_number.value - cstart = SQLiteCloud_number.cstart + sqlitecloud_number = self._internal_parse_number(buffer, start) + number_len = sqlitecloud_number.value + cstart = sqlitecloud_number.cstart value = buffer[cstart : cstart + number_len] rowset.decltype.append(value.decode()) start = cstart + number_len @@ -979,9 +1001,9 @@ def _internal_parse_rowset_header( # parse database names rowset.dbname = [] for i in range(ncols): - SQLiteCloud_number = self._internal_parse_number(buffer, start) - number_len = SQLiteCloud_number.value - cstart = SQLiteCloud_number.cstart + sqlitecloud_number = self._internal_parse_number(buffer, start) + number_len = sqlitecloud_number.value + cstart = sqlitecloud_number.cstart value = buffer[cstart : cstart + number_len] rowset.dbname.append(value.decode()) start = cstart + number_len @@ -989,9 +1011,9 @@ def _internal_parse_rowset_header( # parse table names rowset.tblname = [] for i in range(ncols): - SQLiteCloud_number = self._internal_parse_number(buffer, start) - number_len = SQLiteCloud_number.value - cstart = SQLiteCloud_number.cstart + sqlitecloud_number = self._internal_parse_number(buffer, start) + number_len = sqlitecloud_number.value + cstart = sqlitecloud_number.cstart value = buffer[cstart : cstart + number_len] rowset.tblname.append(value.decode()) start = cstart + number_len @@ -999,9 +1021,9 @@ def _internal_parse_rowset_header( # parse column original names rowset.origname = [] for i in range(ncols): - SQLiteCloud_number = self._internal_parse_number(buffer, start) - number_len = SQLiteCloud_number.value - cstart = SQLiteCloud_number.cstart + sqlitecloud_number = self._internal_parse_number(buffer, start) + number_len = sqlitecloud_number.value + cstart = sqlitecloud_number.cstart value = buffer[cstart : cstart + number_len] rowset.origname.append(value.decode()) start = cstart + number_len @@ -1009,23 +1031,23 @@ def _internal_parse_rowset_header( # parse not null flags rowset.notnull = [] for i in range(ncols): - SQLiteCloud_number = self._internal_parse_number(buffer, start) - rowset.notnull.append(SQLiteCloud_number.value) - start = SQLiteCloud_number.cstart + sqlitecloud_number = self._internal_parse_number(buffer, start) + rowset.notnull.append(sqlitecloud_number.value) + start = sqlitecloud_number.cstart # parse primary key flags rowset.prikey = [] for i in range(ncols): - SQLiteCloud_number = self._internal_parse_number(buffer, start) - rowset.prikey.append(SQLiteCloud_number.value) - start = SQLiteCloud_number.cstart + sqlitecloud_number = self._internal_parse_number(buffer, start) + rowset.prikey.append(sqlitecloud_number.value) + start = sqlitecloud_number.cstart # parse autoincrement flags rowset.autoinc = [] for i in range(ncols): - SQLiteCloud_number = self._internal_parse_number(buffer, start) - rowset.autoinc.append(SQLiteCloud_number.value) - start = SQLiteCloud_number.cstart + sqlitecloud_number = self._internal_parse_number(buffer, start) + rowset.autoinc.append(sqlitecloud_number.value) + start = sqlitecloud_number.cstart return start @@ -1034,6 +1056,6 @@ def _internal_parse_rowset_values( ): # loop to parse each individual value for i in range(bound): - SQLiteCloud_value = self._internal_parse_value(buffer, start) - start += SQLiteCloud_value.cellsize - rowset.data.append(SQLiteCloud_value.value) + sqlitecloud_value = self._internal_parse_value(buffer, start) + start += sqlitecloud_value.cellsize + rowset.data.append(sqlitecloud_value.value) diff --git a/src/tests/integration/test_client.py b/src/tests/integration/test_client.py index f0a7a81..52f342a 100644 --- a/src/tests/integration/test_client.py +++ b/src/tests/integration/test_client.py @@ -1,3 +1,4 @@ +import base64 import json import os import time @@ -5,6 +6,7 @@ import pytest from sqlitecloud.client import SQLiteCloudClient +from sqlitecloud.driver import Driver from sqlitecloud.types import ( SQLITECLOUD_ERRCODE, SQLITECLOUD_INTERNAL_ERRCODE, @@ -248,7 +250,7 @@ def test_json(self, sqlitecloud_connection): "soldier: Well they don't seem to move anymore...", "supreme-commander: Oh snap, I came here to see them twerk!", ], - } == json.loads(result.get_result()) + } == result.get_result() def test_blob(self, sqlitecloud_connection): connection, client = sqlitecloud_connection @@ -257,7 +259,7 @@ def test_blob(self, sqlitecloud_connection): assert SQLITECLOUD_RESULT_TYPE.RESULT_BLOB == result.tag assert len(result.get_result()) == 1000 - def test_blob0(self, sqlitecloud_connection): + def test_blob_zero_length(self, sqlitecloud_connection): connection, client = sqlitecloud_connection result = client.exec_query("TEST BLOB0", connection) @@ -366,7 +368,7 @@ def test_max_rowset_option_to_succeed_when_rowset_is_lighter(self): assert 1 == rowset.nrows - def test_chunked_rowset(self, sqlitecloud_connection): + def test_rowset_chunk(self, sqlitecloud_connection): connection, client = sqlitecloud_connection rowset = client.exec_query("TEST ROWSET_CHUNK", connection) @@ -377,6 +379,17 @@ def test_chunked_rowset(self, sqlitecloud_connection): assert 147 == len(rowset.data) assert "key" == rowset.get_name(0) + def test_rowset_nochunk(self, sqlitecloud_connection): + connection, client = sqlitecloud_connection + + rowset = client.exec_query("TEST ROWSET_NOCHUNK", connection) + + assert SQLITECLOUD_RESULT_TYPE.RESULT_ROWSET == rowset.tag + assert 147 == rowset.nrows + assert 1 == rowset.ncols + assert 147 == len(rowset.data) + assert "key" == rowset.get_name(0) + def test_chunked_rowset_twice(self, sqlitecloud_connection): connection, client = sqlitecloud_connection rowset = client.exec_query("TEST ROWSET_CHUNK", connection) @@ -441,7 +454,7 @@ def test_query_timeout(self): client.disconnect(connection) assert e.value.errcode == SQLITECLOUD_INTERNAL_ERRCODE.NETWORK - assert e.value.errmsg == "An error occurred while reading data from the socket." + assert e.value.errmsg == "An error occurred while reading command length from the socket." def test_XXL_query(self, sqlitecloud_connection): connection, client = sqlitecloud_connection @@ -651,6 +664,62 @@ def test_compression_multiple_columns(self): assert rowset.ncols > 0 assert rowset.get_name(0) == "AlbumId" + def test_compression_big_rowset(self): + account = SQLiteCloudAccount() + account.hostname = os.getenv("SQLITE_HOST") + account.apikey = os.getenv("SQLITE_API_KEY") + account.dbname = os.getenv("SQLITE_DB") + + client = SQLiteCloudClient(cloud_account=account) + client.config.compression = True + + connection = client.open_connection() + + try: + client.exec_query("CREATE TABLE IF NOT EXISTS TestCompress (id INTEGER PRIMARY KEY, name TEXT)", connection) + client.exec_query("DELETE FROM TestCompress", connection) + + nRows = 10000 + + sql = "BEGIN; " + sql = "" + for i in range(nRows): + sql += f"INSERT INTO TestCompress (name) VALUES ('Test {i}'); " + sql += "COMMIT;" + + client.exec_query(sql, connection) + + rowset = client.exec_query( + "SELECT * from TestCompress", + connection, + ) + + assert rowset.nrows == nRows + finally: + client.disconnect(connection) + + def test_rowset_nochunk_compressed(self, sqlitecloud_connection): + connection, client = sqlitecloud_connection + + rowset = client.exec_query("TEST ROWSET_NOCHUNK_COMPRESSED", connection) + + assert SQLITECLOUD_RESULT_TYPE.RESULT_ROWSET == rowset.tag + assert 147 == rowset.nrows + assert 1 == rowset.ncols + assert 147 == len(rowset.data) + assert "key" == rowset.get_name(0) + + def test_rowset_chunk_compressed(self, sqlitecloud_connection): + connection, client = sqlitecloud_connection + + rowset = client.exec_query("TEST ROWSET_CHUNK_COMPRESSED", connection) + + assert SQLITECLOUD_RESULT_TYPE.RESULT_ROWSET == rowset.tag + assert 147 == rowset.nrows + assert 1 == rowset.ncols + assert 147 == len(rowset.data) + assert "key" == rowset.get_name(0) + def test_exec_statement_with_named_placeholder(self, sqlitecloud_connection): connection, client = sqlitecloud_connection diff --git a/src/tests/integration/test_pubsub.py b/src/tests/integration/test_pubsub.py index be4b98a..4058201 100644 --- a/src/tests/integration/test_pubsub.py +++ b/src/tests/integration/test_pubsub.py @@ -26,16 +26,23 @@ def assert_callback(conn, result, data): if isinstance(result, SQLiteCloudResultSet): assert result.tag == SQLITECLOUD_RESULT_TYPE.RESULT_JSON + assert isinstance(result.get_result(), dict) + + content = result.get_result() + assert content["channel"] == channel + assert len(content["payload"]) > 0 + assert content["payload"] == "somedata2" + assert data == ["somedata"] callback_called = True flag.set() pubsub = SQLiteCloudPubSub() - type = SQLITECLOUD_PUBSUB_SUBJECT.CHANNEL + subject_type = SQLITECLOUD_PUBSUB_SUBJECT.CHANNEL channel = "channel" + str(uuid.uuid4()) pubsub.create_channel(connection, channel) - pubsub.listen(connection, type, channel, assert_callback, ["somedata"]) + pubsub.listen(connection, subject_type, channel, assert_callback, ["somedata"]) pubsub.notify_channel(connection, channel, "somedata2") @@ -48,16 +55,16 @@ def test_unlisten_channel(self, sqlitecloud_connection): connection, _ = sqlitecloud_connection pubsub = SQLiteCloudPubSub() - type = SQLITECLOUD_PUBSUB_SUBJECT.CHANNEL + subject_type = SQLITECLOUD_PUBSUB_SUBJECT.CHANNEL channel_name = "channel" + str(uuid.uuid4()) pubsub.create_channel(connection, channel_name) - pubsub.listen(connection, type, channel_name, lambda conn, result, data: None) + pubsub.listen(connection, subject_type, channel_name, lambda conn, result, data: None) result = pubsub.list_connections(connection) assert channel_name in result.data - pubsub.unlisten(connection, type, channel_name) + pubsub.unlisten(connection, subject_type, channel_name) result = pubsub.list_connections(connection) @@ -116,11 +123,11 @@ def assert_callback(conn, result, data): flag.set() pubsub = SQLiteCloudPubSub() - type = SQLITECLOUD_PUBSUB_SUBJECT.CHANNEL + subject_type = SQLITECLOUD_PUBSUB_SUBJECT.CHANNEL channel = "channel" + str(uuid.uuid4()) pubsub.create_channel(connection, channel, if_not_exists=True) - pubsub.listen(connection, type, channel, assert_callback) + pubsub.listen(connection, subject_type, channel, assert_callback) pubsub.set_pubsub_only(connection) @@ -150,16 +157,23 @@ def assert_callback(conn, result, data): if isinstance(result, SQLiteCloudResultSet): assert result.tag == SQLITECLOUD_RESULT_TYPE.RESULT_JSON - assert new_name in result.get_result() + assert isinstance(result.get_result(), dict) + + content = result.get_result() + assert content["channel"] == "genres" + assert len(content["payload"]) > 0 + assert content["payload"][0]["Name"] == new_name + assert content["payload"][0]["type"] == 'UPDATE' + assert data == ["somedata"] callback_called = True flag.set() pubsub = SQLiteCloudPubSub() - type = SQLITECLOUD_PUBSUB_SUBJECT.TABLE + subject_type = SQLITECLOUD_PUBSUB_SUBJECT.TABLE new_name = "Rock" + str(uuid.uuid4()) - pubsub.listen(connection, type, "genres", assert_callback, ["somedata"]) + pubsub.listen(connection, subject_type, "genres", assert_callback, ["somedata"]) client.exec_query( f"UPDATE genres SET Name = '{new_name}' WHERE GenreId = 1;", connection From 421e41de84d3d891f4df5cb18fa1e90dfcba7aee Mon Sep 17 00:00:00 2001 From: Daniele Briggi Date: Tue, 6 Aug 2024 12:36:13 +0000 Subject: [PATCH 2/2] refact: types.py is used by Py --- bandit-baseline.json | 121 +++++++++++-------- src/sqlitecloud/client.py | 6 +- src/sqlitecloud/{types.py => datatypes.py} | 31 +---- src/sqlitecloud/dbapi2.py | 9 +- src/sqlitecloud/download.py | 2 +- src/sqlitecloud/driver.py | 13 +- src/sqlitecloud/pubsub.py | 2 +- src/sqlitecloud/resultset.py | 22 +++- src/sqlitecloud/upload.py | 2 +- src/tests/conftest.py | 2 +- src/tests/integration/test_client.py | 80 ++++++------ src/tests/integration/test_dbapi2.py | 2 +- src/tests/integration/test_download.py | 2 +- src/tests/integration/test_pubsub.py | 15 +-- src/tests/integration/test_sqlite3_parity.py | 2 +- src/tests/unit/test_dbapi2.py | 11 +- src/tests/unit/test_driver.py | 2 +- src/tests/unit/test_resultset.py | 8 +- src/tests/unit/test_types.py | 2 +- 19 files changed, 176 insertions(+), 158 deletions(-) rename src/sqlitecloud/{types.py => datatypes.py} (92%) diff --git a/bandit-baseline.json b/bandit-baseline.json index b3b2040..aee9ade 100644 --- a/bandit-baseline.json +++ b/bandit-baseline.json @@ -1,17 +1,17 @@ { "errors": [], - "generated_at": "2024-06-03T07:52:17Z", + "generated_at": "2024-08-06T12:35:09Z", "metrics": { "_totals": { "CONFIDENCE.HIGH": 0.0, - "CONFIDENCE.LOW": 2.0, + "CONFIDENCE.LOW": 3.0, "CONFIDENCE.MEDIUM": 1.0, "CONFIDENCE.UNDEFINED": 0.0, "SEVERITY.HIGH": 0.0, "SEVERITY.LOW": 1.0, - "SEVERITY.MEDIUM": 2.0, + "SEVERITY.MEDIUM": 3.0, "SEVERITY.UNDEFINED": 0.0, - "loc": 3405, + "loc": 3497, "nosec": 0 }, "src/setup.py": { @@ -50,19 +50,19 @@ "loc": 104, "nosec": 0 }, - "src/sqlitecloud/dbapi2.py": { + "src/sqlitecloud/datatypes.py": { "CONFIDENCE.HIGH": 0.0, "CONFIDENCE.LOW": 0.0, - "CONFIDENCE.MEDIUM": 0.0, + "CONFIDENCE.MEDIUM": 1.0, "CONFIDENCE.UNDEFINED": 0.0, "SEVERITY.HIGH": 0.0, - "SEVERITY.LOW": 0.0, + "SEVERITY.LOW": 1.0, "SEVERITY.MEDIUM": 0.0, "SEVERITY.UNDEFINED": 0.0, - "loc": 377, + "loc": 177, "nosec": 0 }, - "src/sqlitecloud/download.py": { + "src/sqlitecloud/dbapi2.py": { "CONFIDENCE.HIGH": 0.0, "CONFIDENCE.LOW": 0.0, "CONFIDENCE.MEDIUM": 0.0, @@ -71,10 +71,10 @@ "SEVERITY.LOW": 0.0, "SEVERITY.MEDIUM": 0.0, "SEVERITY.UNDEFINED": 0.0, - "loc": 32, + "loc": 376, "nosec": 0 }, - "src/sqlitecloud/driver.py": { + "src/sqlitecloud/download.py": { "CONFIDENCE.HIGH": 0.0, "CONFIDENCE.LOW": 0.0, "CONFIDENCE.MEDIUM": 0.0, @@ -83,10 +83,10 @@ "SEVERITY.LOW": 0.0, "SEVERITY.MEDIUM": 0.0, "SEVERITY.UNDEFINED": 0.0, - "loc": 765, + "loc": 32, "nosec": 0 }, - "src/sqlitecloud/pubsub.py": { + "src/sqlitecloud/driver.py": { "CONFIDENCE.HIGH": 0.0, "CONFIDENCE.LOW": 0.0, "CONFIDENCE.MEDIUM": 0.0, @@ -95,10 +95,10 @@ "SEVERITY.LOW": 0.0, "SEVERITY.MEDIUM": 0.0, "SEVERITY.UNDEFINED": 0.0, - "loc": 54, + "loc": 787, "nosec": 0 }, - "src/sqlitecloud/resultset.py": { + "src/sqlitecloud/pubsub.py": { "CONFIDENCE.HIGH": 0.0, "CONFIDENCE.LOW": 0.0, "CONFIDENCE.MEDIUM": 0.0, @@ -107,19 +107,19 @@ "SEVERITY.LOW": 0.0, "SEVERITY.MEDIUM": 0.0, "SEVERITY.UNDEFINED": 0.0, - "loc": 80, + "loc": 56, "nosec": 0 }, - "src/sqlitecloud/types.py": { + "src/sqlitecloud/resultset.py": { "CONFIDENCE.HIGH": 0.0, "CONFIDENCE.LOW": 0.0, - "CONFIDENCE.MEDIUM": 1.0, + "CONFIDENCE.MEDIUM": 0.0, "CONFIDENCE.UNDEFINED": 0.0, "SEVERITY.HIGH": 0.0, - "SEVERITY.LOW": 1.0, + "SEVERITY.LOW": 0.0, "SEVERITY.MEDIUM": 0.0, "SEVERITY.UNDEFINED": 0.0, - "loc": 194, + "loc": 99, "nosec": 0 }, "src/sqlitecloud/upload.py": { @@ -172,14 +172,14 @@ }, "src/tests/integration/test_client.py": { "CONFIDENCE.HIGH": 0.0, - "CONFIDENCE.LOW": 0.0, + "CONFIDENCE.LOW": 1.0, "CONFIDENCE.MEDIUM": 0.0, "CONFIDENCE.UNDEFINED": 0.0, "SEVERITY.HIGH": 0.0, "SEVERITY.LOW": 0.0, - "SEVERITY.MEDIUM": 0.0, + "SEVERITY.MEDIUM": 1.0, "SEVERITY.UNDEFINED": 0.0, - "loc": 492, + "loc": 543, "nosec": 0 }, "src/tests/integration/test_dbapi2.py": { @@ -239,7 +239,7 @@ "SEVERITY.LOW": 0.0, "SEVERITY.MEDIUM": 1.0, "SEVERITY.UNDEFINED": 0.0, - "loc": 118, + "loc": 129, "nosec": 0 }, "src/tests/integration/test_sqlite3_parity.py": { @@ -287,7 +287,7 @@ "SEVERITY.LOW": 0.0, "SEVERITY.MEDIUM": 0.0, "SEVERITY.UNDEFINED": 0.0, - "loc": 242, + "loc": 241, "nosec": 0 }, "src/tests/unit/test_driver.py": { @@ -311,7 +311,7 @@ "SEVERITY.LOW": 0.0, "SEVERITY.MEDIUM": 0.0, "SEVERITY.UNDEFINED": 0.0, - "loc": 113, + "loc": 119, "nosec": 0 }, "src/tests/unit/test_types.py": { @@ -329,51 +329,66 @@ }, "results": [ { - "code": "107 class SQLiteCloudAccount:\n108 def __init__(\n109 self,\n110 username: Optional[str] = \"\",\n111 password: Optional[str] = \"\",\n112 hostname: str = \"\",\n113 dbname: Optional[str] = \"\",\n114 port: int = SQLITECLOUD_DEFAULT.PORT.value,\n115 apikey: Optional[str] = \"\",\n116 ) -> None:\n117 # User name is required unless connectionstring is provided\n118 self.username = username\n119 # Password is required unless connection string is provided\n120 self.password = password\n121 # Password is hashed\n122 self.password_hashed = False\n123 # API key instead of username and password\n124 self.apikey = apikey\n125 # Name of database to open\n126 self.dbname = dbname\n127 # Like mynode.sqlitecloud.io\n128 self.hostname = hostname\n129 self.port = port\n130 \n", + "code": "87 class SQLiteCloudAccount:\n88 def __init__(\n89 self,\n90 username: Optional[str] = \"\",\n91 password: Optional[str] = \"\",\n92 hostname: str = \"\",\n93 dbname: Optional[str] = \"\",\n94 port: int = SQLITECLOUD_DEFAULT.PORT.value,\n95 apikey: Optional[str] = \"\",\n96 ) -> None:\n97 # User name is required unless connectionstring is provided\n98 self.username = username\n99 # Password is required unless connection string is provided\n100 self.password = password\n101 # Password is hashed\n102 self.password_hashed = False\n103 # API key instead of username and password\n104 self.apikey = apikey\n105 # Name of database to open\n106 self.dbname = dbname\n107 # Like mynode.sqlitecloud.io\n108 self.hostname = hostname\n109 self.port = port\n110 \n", "col_offset": 4, - "filename": "src/sqlitecloud/types.py", + "filename": "src/sqlitecloud/datatypes.py", "issue_confidence": "MEDIUM", "issue_severity": "LOW", "issue_text": "Possible hardcoded password: ''", - "line_number": 108, + "line_number": 88, "line_range": [ + 88, + 89, + 90, + 91, + 92, + 93, + 94, + 95, + 96, + 97, + 98, + 99, + 100, + 101, + 102, + 103, + 104, + 105, + 106, + 107, 108, - 109, - 110, - 111, - 112, - 113, - 114, - 115, - 116, - 117, - 118, - 119, - 120, - 121, - 122, - 123, - 124, - 125, - 126, - 127, - 128, - 129 + 109 ], "more_info": "https://bandit.readthedocs.io/en/latest/plugins/b107_hardcoded_password_default.html", "test_id": "B107", "test_name": "hardcoded_password_default" }, { - "code": "164 client.exec_query(\n165 f\"UPDATE genres SET Name = '{new_name}' WHERE GenreId = 1;\", connection\n166 )\n", + "code": "639 for i in range(nRows):\n640 sql += f\"INSERT INTO TestCompress (name) VALUES ('Test {i}'); \"\n641 \n", + "col_offset": 23, + "filename": "src/tests/integration/test_client.py", + "issue_confidence": "LOW", + "issue_severity": "MEDIUM", + "issue_text": "Possible SQL injection vector through string-based query construction.", + "line_number": 640, + "line_range": [ + 640 + ], + "more_info": "https://bandit.readthedocs.io/en/latest/plugins/b608_hardcoded_sql_expressions.html", + "test_id": "B608", + "test_name": "hardcoded_sql_expressions" + }, + { + "code": "179 client.exec_query(\n180 f\"UPDATE genres SET Name = '{new_name}' WHERE GenreId = 1;\", connection\n181 )\n", "col_offset": 12, "filename": "src/tests/integration/test_pubsub.py", "issue_confidence": "LOW", "issue_severity": "MEDIUM", "issue_text": "Possible SQL injection vector through string-based query construction.", - "line_number": 165, + "line_number": 180, "line_range": [ - 165 + 180 ], "more_info": "https://bandit.readthedocs.io/en/latest/plugins/b608_hardcoded_sql_expressions.html", "test_id": "B608", diff --git a/src/sqlitecloud/client.py b/src/sqlitecloud/client.py index 375fd36..6e4a897 100644 --- a/src/sqlitecloud/client.py +++ b/src/sqlitecloud/client.py @@ -3,15 +3,15 @@ """ from typing import Dict, Optional, Tuple, Union -from sqlitecloud.driver import Driver -from sqlitecloud.resultset import SQLiteCloudResultSet -from sqlitecloud.types import ( +from sqlitecloud.datatypes import ( SQLiteCloudAccount, SQLiteCloudConfig, SQLiteCloudConnect, SQLiteCloudDataTypes, SQLiteCloudException, ) +from sqlitecloud.driver import Driver +from sqlitecloud.resultset import SQLiteCloudResultSet class SQLiteCloudClient: diff --git a/src/sqlitecloud/types.py b/src/sqlitecloud/datatypes.py similarity index 92% rename from src/sqlitecloud/types.py rename to src/sqlitecloud/datatypes.py index 9e29142..5d7c1dd 100644 --- a/src/sqlitecloud/types.py +++ b/src/sqlitecloud/datatypes.py @@ -1,16 +1,17 @@ -import types from asyncio import AbstractEventLoop from enum import Enum from typing import Any, Callable, Dict, Optional, Union from urllib import parse +from .resultset import SQLiteCloudResultSet + # Basic types supported by SQLite Cloud APIs SQLiteCloudDataTypes = Union[str, int, bool, Dict[Union[str, int], Any], bytes, None] class SQLITECLOUD_DEFAULT(Enum): PORT = 8860 - TIMEOUT = 12 + TIMEOUT = 30 UPLOAD_SIZE = 512 * 1024 @@ -37,14 +38,6 @@ class SQLITECLOUD_ROWSET(Enum): CHUNKS_END = b"/6 0 0 0 " -class SQLITECLOUD_VALUE_TYPE(Enum): - INTEGER = "INTEGER" - FLOAT = "REAL" - TEXT = "TEXT" - BLOB = "BLOB" - NULL = "NULL" - - class SQLITECLOUD_INTERNAL_ERRCODE(Enum): """ Clients error codes. @@ -68,19 +61,6 @@ class SQLITECLOUD_ERRCODE(Enum): RAFT = 10006 -class SQLITECLOUD_RESULT_TYPE(Enum): - RESULT_OK = 0 - RESULT_ERROR = 1 - RESULT_STRING = 2 - RESULT_INTEGER = 3 - RESULT_FLOAT = 4 - RESULT_ROWSET = 5 - RESULT_ARRAY = 6 - RESULT_NONE = 7 - RESULT_JSON = 8 - RESULT_BLOB = 9 - - class SQLITECLOUD_PUBSUB_SUBJECT(Enum): """ Subjects that can be subscribed to by PubSub. @@ -141,7 +121,7 @@ def __init__(self): self.pubsub_socket: any = None self.pubsub_callback: Callable[ - [SQLiteCloudConnect, Optional[types.SqliteCloudResultSet], Optional[any]], + [SQLiteCloudConnect, Optional[SQLiteCloudResultSet], Optional[any]], None, ] = None self.pubsub_data: any = None @@ -210,8 +190,6 @@ def _parse_connection_string(self, connection_string) -> None: value = bool(value) elif value.isdigit(): value = int(value) - else: - value = value # alias if opt == "nonlinearizable": @@ -248,6 +226,7 @@ def _parse_connection_string(self, connection_string) -> None: class SQLiteCloudException(Exception): def __init__(self, message: str, code: int = -1, xerrcode: int = 0) -> None: + super().__init__(message) self.errmsg = str(message) self.errcode = code self.xerrcode = xerrcode diff --git a/src/sqlitecloud/dbapi2.py b/src/sqlitecloud/dbapi2.py index 06ba402..af56a74 100644 --- a/src/sqlitecloud/dbapi2.py +++ b/src/sqlitecloud/dbapi2.py @@ -16,16 +16,15 @@ overload, ) -from sqlitecloud.driver import Driver -from sqlitecloud.resultset import SQLiteCloudResult -from sqlitecloud.types import ( - SQLITECLOUD_RESULT_TYPE, +from sqlitecloud.datatypes import ( SQLiteCloudAccount, SQLiteCloudConfig, SQLiteCloudConnect, SQLiteCloudDataTypes, SQLiteCloudException, ) +from sqlitecloud.driver import Driver +from sqlitecloud.resultset import SQLITECLOUD_RESULT_TYPE, SQLiteCloudResult # Question mark style, e.g. ...WHERE name=? # Module also supports Named style, e.g. ...WHERE name=:name @@ -478,7 +477,7 @@ def __next__(self) -> Optional[Tuple[Any]]: and self._resultset.data and self._iter_row < self._resultset.nrows ): - out: tuple[Any] = () + out: Tuple[Any] = () for col in range(self._resultset.ncols): out += (self._resultset.get_value(self._iter_row, col),) diff --git a/src/sqlitecloud/download.py b/src/sqlitecloud/download.py index 37eb8b6..3955cc9 100644 --- a/src/sqlitecloud/download.py +++ b/src/sqlitecloud/download.py @@ -1,8 +1,8 @@ import logging from io import BufferedWriter +from sqlitecloud.datatypes import SQLiteCloudConnect from sqlitecloud.driver import Driver -from sqlitecloud.types import SQLiteCloudConnect def xCallback( diff --git a/src/sqlitecloud/driver.py b/src/sqlitecloud/driver.py index 08e12be..905f38e 100644 --- a/src/sqlitecloud/driver.py +++ b/src/sqlitecloud/driver.py @@ -9,12 +9,10 @@ import lz4.block -from sqlitecloud.resultset import SQLiteCloudResult, SQLiteCloudResultSet -from sqlitecloud.types import ( +from sqlitecloud.datatypes import ( SQLITECLOUD_CMD, SQLITECLOUD_DEFAULT, SQLITECLOUD_INTERNAL_ERRCODE, - SQLITECLOUD_RESULT_TYPE, SQLITECLOUD_ROWSET, SQLiteCloudConfig, SQLiteCloudConnect, @@ -24,11 +22,14 @@ SQLiteCloudRowsetSignature, SQLiteCloudValue, ) +from sqlitecloud.resultset import ( + SQLITECLOUD_RESULT_TYPE, + SQLiteCloudResult, + SQLiteCloudResultSet, +) class Driver: - SQLiteCloud_DEFAULT_UPLOAD_SIZE = 512 * 1024 - def __init__(self) -> None: # Used while parsing chunked rowset self._rowset: SQLiteCloudResult = None @@ -759,7 +760,7 @@ def _internal_parse_buffer( rowset_signature.ncols, ) - # continue reading from the socket + # continue reading from the socket # until the end-of-chunk condition if cmd == SQLITECLOUD_CMD.ROWSET_CHUNK.value: return self._internal_socket_read(connection) diff --git a/src/sqlitecloud/pubsub.py b/src/sqlitecloud/pubsub.py index 7999908..b61cc4a 100644 --- a/src/sqlitecloud/pubsub.py +++ b/src/sqlitecloud/pubsub.py @@ -1,8 +1,8 @@ from typing import Callable, Optional +from sqlitecloud.datatypes import SQLITECLOUD_PUBSUB_SUBJECT, SQLiteCloudConnect from sqlitecloud.driver import Driver from sqlitecloud.resultset import SQLiteCloudResultSet -from sqlitecloud.types import SQLITECLOUD_PUBSUB_SUBJECT, SQLiteCloudConnect class SQLiteCloudPubSub: diff --git a/src/sqlitecloud/resultset.py b/src/sqlitecloud/resultset.py index c6160d7..6fcc481 100644 --- a/src/sqlitecloud/resultset.py +++ b/src/sqlitecloud/resultset.py @@ -1,6 +1,26 @@ +from enum import Enum from typing import Any, Dict, List, Optional -from sqlitecloud.types import SQLITECLOUD_RESULT_TYPE, SQLITECLOUD_VALUE_TYPE + +class SQLITECLOUD_VALUE_TYPE(Enum): + INTEGER = "INTEGER" + FLOAT = "REAL" + TEXT = "TEXT" + BLOB = "BLOB" + NULL = "NULL" + + +class SQLITECLOUD_RESULT_TYPE(Enum): + RESULT_OK = 0 + RESULT_ERROR = 1 + RESULT_STRING = 2 + RESULT_INTEGER = 3 + RESULT_FLOAT = 4 + RESULT_ROWSET = 5 + RESULT_ARRAY = 6 + RESULT_NONE = 7 + RESULT_JSON = 8 + RESULT_BLOB = 9 class SQLiteCloudResult: diff --git a/src/sqlitecloud/upload.py b/src/sqlitecloud/upload.py index 3dd10e1..5226965 100644 --- a/src/sqlitecloud/upload.py +++ b/src/sqlitecloud/upload.py @@ -3,8 +3,8 @@ from io import BufferedReader from typing import Optional +from sqlitecloud.datatypes import SQLiteCloudConnect from sqlitecloud.driver import Driver -from sqlitecloud.types import SQLiteCloudConnect def xCallback(fd: BufferedReader, blen: int, ntot: int, nprogress: int) -> bytes: diff --git a/src/tests/conftest.py b/src/tests/conftest.py index b1db511..8d45371 100644 --- a/src/tests/conftest.py +++ b/src/tests/conftest.py @@ -5,7 +5,7 @@ import sqlitecloud from sqlitecloud.client import SQLiteCloudClient -from sqlitecloud.types import SQLiteCloudAccount, SQLiteCloudConnect +from sqlitecloud.datatypes import SQLiteCloudAccount, SQLiteCloudConnect @pytest.fixture(autouse=True) diff --git a/src/tests/integration/test_client.py b/src/tests/integration/test_client.py index 52f342a..3680e1d 100644 --- a/src/tests/integration/test_client.py +++ b/src/tests/integration/test_client.py @@ -1,20 +1,17 @@ -import base64 -import json import os import time import pytest from sqlitecloud.client import SQLiteCloudClient -from sqlitecloud.driver import Driver -from sqlitecloud.types import ( +from sqlitecloud.datatypes import ( SQLITECLOUD_ERRCODE, SQLITECLOUD_INTERNAL_ERRCODE, - SQLITECLOUD_RESULT_TYPE, SQLiteCloudAccount, SQLiteCloudConnect, SQLiteCloudException, ) +from sqlitecloud.resultset import SQLITECLOUD_RESULT_TYPE class TestClient: @@ -454,7 +451,10 @@ def test_query_timeout(self): client.disconnect(connection) assert e.value.errcode == SQLITECLOUD_INTERNAL_ERRCODE.NETWORK - assert e.value.errmsg == "An error occurred while reading command length from the socket." + assert ( + e.value.errmsg + == "An error occurred while reading command length from the socket." + ) def test_XXL_query(self, sqlitecloud_connection): connection, client = sqlitecloud_connection @@ -616,6 +616,40 @@ def test_stress_test_20x_batched_selects(self, sqlitecloud_connection): query_ms < self.EXPECT_SPEED_MS ), f"{num_queries}x batched selects, {query_ms}ms per query" + def test_big_rowset(self): + account = SQLiteCloudAccount() + account.hostname = os.getenv("SQLITE_HOST") + account.apikey = os.getenv("SQLITE_API_KEY") + account.dbname = os.getenv("SQLITE_DB") + + client = SQLiteCloudClient(cloud_account=account) + + connection = client.open_connection() + + try: + client.exec_query( + "CREATE TABLE IF NOT EXISTS TestCompress (id INTEGER PRIMARY KEY, name TEXT)", + connection, + ) + client.exec_query("DELETE FROM TestCompress", connection) + + nRows = 1000 + + sql = "" + for i in range(nRows): + sql += f"INSERT INTO TestCompress (name) VALUES ('Test {i}'); " + + client.exec_query(sql, connection) + + rowset = client.exec_query( + "SELECT * from TestCompress", + connection, + ) + + assert rowset.nrows == nRows + finally: + client.disconnect(connection) + def test_compression_single_column(self): account = SQLiteCloudAccount() account.hostname = os.getenv("SQLITE_HOST") @@ -664,40 +698,6 @@ def test_compression_multiple_columns(self): assert rowset.ncols > 0 assert rowset.get_name(0) == "AlbumId" - def test_compression_big_rowset(self): - account = SQLiteCloudAccount() - account.hostname = os.getenv("SQLITE_HOST") - account.apikey = os.getenv("SQLITE_API_KEY") - account.dbname = os.getenv("SQLITE_DB") - - client = SQLiteCloudClient(cloud_account=account) - client.config.compression = True - - connection = client.open_connection() - - try: - client.exec_query("CREATE TABLE IF NOT EXISTS TestCompress (id INTEGER PRIMARY KEY, name TEXT)", connection) - client.exec_query("DELETE FROM TestCompress", connection) - - nRows = 10000 - - sql = "BEGIN; " - sql = "" - for i in range(nRows): - sql += f"INSERT INTO TestCompress (name) VALUES ('Test {i}'); " - sql += "COMMIT;" - - client.exec_query(sql, connection) - - rowset = client.exec_query( - "SELECT * from TestCompress", - connection, - ) - - assert rowset.nrows == nRows - finally: - client.disconnect(connection) - def test_rowset_nochunk_compressed(self, sqlitecloud_connection): connection, client = sqlitecloud_connection diff --git a/src/tests/integration/test_dbapi2.py b/src/tests/integration/test_dbapi2.py index e97f383..53b2a18 100644 --- a/src/tests/integration/test_dbapi2.py +++ b/src/tests/integration/test_dbapi2.py @@ -4,7 +4,7 @@ import pytest import sqlitecloud -from sqlitecloud.types import ( +from sqlitecloud.datatypes import ( SQLITECLOUD_INTERNAL_ERRCODE, SQLiteCloudAccount, SQLiteCloudException, diff --git a/src/tests/integration/test_download.py b/src/tests/integration/test_download.py index ddd5f46..7c7c8af 100644 --- a/src/tests/integration/test_download.py +++ b/src/tests/integration/test_download.py @@ -4,7 +4,7 @@ import pytest from sqlitecloud import download -from sqlitecloud.types import SQLITECLOUD_ERRCODE, SQLiteCloudException +from sqlitecloud.datatypes import SQLITECLOUD_ERRCODE, SQLiteCloudException class TestDownload: diff --git a/src/tests/integration/test_pubsub.py b/src/tests/integration/test_pubsub.py index 4058201..03d18c9 100644 --- a/src/tests/integration/test_pubsub.py +++ b/src/tests/integration/test_pubsub.py @@ -3,14 +3,13 @@ import pytest -from sqlitecloud.pubsub import SQLiteCloudPubSub -from sqlitecloud.resultset import SQLiteCloudResultSet -from sqlitecloud.types import ( +from sqlitecloud.datatypes import ( SQLITECLOUD_ERRCODE, SQLITECLOUD_PUBSUB_SUBJECT, - SQLITECLOUD_RESULT_TYPE, SQLiteCloudException, ) +from sqlitecloud.pubsub import SQLiteCloudPubSub +from sqlitecloud.resultset import SQLITECLOUD_RESULT_TYPE, SQLiteCloudResultSet class TestPubSub: @@ -59,7 +58,9 @@ def test_unlisten_channel(self, sqlitecloud_connection): channel_name = "channel" + str(uuid.uuid4()) pubsub.create_channel(connection, channel_name) - pubsub.listen(connection, subject_type, channel_name, lambda conn, result, data: None) + pubsub.listen( + connection, subject_type, channel_name, lambda conn, result, data: None + ) result = pubsub.list_connections(connection) assert channel_name in result.data @@ -158,12 +159,12 @@ def assert_callback(conn, result, data): if isinstance(result, SQLiteCloudResultSet): assert result.tag == SQLITECLOUD_RESULT_TYPE.RESULT_JSON assert isinstance(result.get_result(), dict) - + content = result.get_result() assert content["channel"] == "genres" assert len(content["payload"]) > 0 assert content["payload"][0]["Name"] == new_name - assert content["payload"][0]["type"] == 'UPDATE' + assert content["payload"][0]["type"] == "UPDATE" assert data == ["somedata"] callback_called = True diff --git a/src/tests/integration/test_sqlite3_parity.py b/src/tests/integration/test_sqlite3_parity.py index e80fc52..bc71fb9 100644 --- a/src/tests/integration/test_sqlite3_parity.py +++ b/src/tests/integration/test_sqlite3_parity.py @@ -3,7 +3,7 @@ import pytest -from sqlitecloud.types import SQLiteCloudException +from sqlitecloud.datatypes import SQLiteCloudException class TestSQLite3FeatureParity: diff --git a/src/tests/unit/test_dbapi2.py b/src/tests/unit/test_dbapi2.py index c76aad5..2ab873b 100644 --- a/src/tests/unit/test_dbapi2.py +++ b/src/tests/unit/test_dbapi2.py @@ -3,15 +3,14 @@ import sqlitecloud from sqlitecloud import Cursor -from sqlitecloud.dbapi2 import Connection -from sqlitecloud.driver import Driver -from sqlitecloud.resultset import SQLiteCloudResult -from sqlitecloud.types import ( - SQLITECLOUD_RESULT_TYPE, +from sqlitecloud.datatypes import ( SQLiteCloudAccount, SQLiteCloudConfig, SQLiteCloudException, ) +from sqlitecloud.dbapi2 import Connection +from sqlitecloud.driver import Driver +from sqlitecloud.resultset import SQLITECLOUD_RESULT_TYPE, SQLiteCloudResult def test_connect_with_account_and_config(mocker: MockerFixture): @@ -296,7 +295,7 @@ def test_iterator(self, mocker): assert list(cursor) == [("myname1",), ("myname2",)] def test_row_factory(self, mocker): - conn = Connection(mocker.patch("sqlitecloud.types.SQLiteCloudConnect")) + conn = Connection(mocker.patch("sqlitecloud.datatypes.SQLiteCloudConnect")) conn.row_factory = lambda x, y: {"name": y[0]} result = SQLiteCloudResult(SQLITECLOUD_RESULT_TYPE.RESULT_ROWSET) diff --git a/src/tests/unit/test_driver.py b/src/tests/unit/test_driver.py index f49ba32..796b0aa 100644 --- a/src/tests/unit/test_driver.py +++ b/src/tests/unit/test_driver.py @@ -1,8 +1,8 @@ import pytest from pytest_mock import MockerFixture +from sqlitecloud.datatypes import SQLiteCloudAccount, SQLiteCloudConfig from sqlitecloud.driver import Driver -from sqlitecloud.types import SQLiteCloudAccount, SQLiteCloudConfig class TestDriver: diff --git a/src/tests/unit/test_resultset.py b/src/tests/unit/test_resultset.py index f89f274..18a43f2 100644 --- a/src/tests/unit/test_resultset.py +++ b/src/tests/unit/test_resultset.py @@ -1,7 +1,11 @@ import pytest -from sqlitecloud.resultset import SQLiteCloudResult, SQLiteCloudResultSet -from sqlitecloud.types import SQLITECLOUD_RESULT_TYPE, SQLITECLOUD_VALUE_TYPE +from sqlitecloud.resultset import ( + SQLITECLOUD_RESULT_TYPE, + SQLITECLOUD_VALUE_TYPE, + SQLiteCloudResult, + SQLiteCloudResultSet, +) class TestSQLiteCloudResult: diff --git a/src/tests/unit/test_types.py b/src/tests/unit/test_types.py index e08e569..ff4cb4c 100644 --- a/src/tests/unit/test_types.py +++ b/src/tests/unit/test_types.py @@ -1,6 +1,6 @@ import pytest -from sqlitecloud.types import SQLiteCloudConfig +from sqlitecloud.datatypes import SQLiteCloudConfig class TestSQLiteCloudConfig: