From 7db2e85d967850482d534b353cf92b64121f9f99 Mon Sep 17 00:00:00 2001 From: Jahnvi Thakkar Date: Wed, 13 Aug 2025 11:25:32 +0530 Subject: [PATCH 1/3] Adding fetchval() in cursor class --- mssql_python/cursor.py | 40 +++++ tests/test_004_cursor.py | 366 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 406 insertions(+) diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index fc48a5a9c..217e04755 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -970,6 +970,46 @@ def nextset(self) -> Union[bool, None]: return True + def fetchval(self): + """ + Fetch the first column of the first row if there are results. + + This is a convenience method for queries that return a single value, + such as SELECT COUNT(*) FROM table, SELECT MAX(id) FROM table, etc. + + Returns: + The value of the first column of the first row, or None if no rows + are available or the first column value is NULL. + + Raises: + Exception: If the cursor is closed. + + Example: + >>> count = cursor.execute('SELECT COUNT(*) FROM users').fetchval() + >>> max_id = cursor.execute('SELECT MAX(id) FROM products').fetchval() + >>> name = cursor.execute('SELECT name FROM users WHERE id = ?', user_id).fetchval() + + Note: + This is a convenience extension beyond the DB-API 2.0 specification. + After calling fetchval(), the cursor position advances by one row, + just like fetchone(). + """ + self._check_closed() # Check if the cursor is closed + + # Fetch the first row + row = self.fetchone() + + # If no row is available, return None + if row is None: + return None + + # If the row has no columns, return None (shouldn't happen in normal cases) + if len(row) == 0: + return None + + # Return the first column value (could be None if the column value is NULL) + return row[0] + def __del__(self): """ Destructor to ensure the cursor is closed when it is no longer needed. diff --git a/tests/test_004_cursor.py b/tests/test_004_cursor.py index 247dc8b5a..2fc09e73d 100644 --- a/tests/test_004_cursor.py +++ b/tests/test_004_cursor.py @@ -2454,6 +2454,372 @@ def test_nextset_diagnostics(cursor, db_connection): print(f"DIAGNOSTIC INFO: {e}") # Don't fail the test - this is just for diagnostics +def test_fetchval_basic_functionality(cursor, db_connection): + """Test basic fetchval functionality with simple queries""" + try: + # Test with COUNT query + cursor.execute("SELECT COUNT(*) FROM sys.databases") + count = cursor.fetchval() + assert isinstance(count, int), "fetchval should return integer for COUNT(*)" + assert count > 0, "COUNT(*) should return positive number" + + # Test with literal value + cursor.execute("SELECT 42") + value = cursor.fetchval() + assert value == 42, "fetchval should return the literal value" + + # Test with string literal + cursor.execute("SELECT 'Hello World'") + text = cursor.fetchval() + assert text == 'Hello World', "fetchval should return string literal" + + except Exception as e: + pytest.fail(f"Basic fetchval functionality test failed: {e}") + +def test_fetchval_different_data_types(cursor, db_connection): + """Test fetchval with different SQL data types""" + try: + # Create test table with different data types + drop_table_if_exists(cursor, "#pytest_fetchval_types") + cursor.execute(""" + CREATE TABLE #pytest_fetchval_types ( + int_col INTEGER, + float_col FLOAT, + decimal_col DECIMAL(10,2), + varchar_col VARCHAR(50), + nvarchar_col NVARCHAR(50), + bit_col BIT, + datetime_col DATETIME, + date_col DATE, + time_col TIME + ) + """) + + # Insert test data + cursor.execute(""" + INSERT INTO #pytest_fetchval_types VALUES + (123, 45.67, 89.12, 'ASCII text', N'Unicode text', 1, + '2024-05-20 12:34:56', '2024-05-20', '12:34:56') + """) + db_connection.commit() + + # Test different data types + test_cases = [ + ("SELECT int_col FROM #pytest_fetchval_types", 123, int), + ("SELECT float_col FROM #pytest_fetchval_types", 45.67, float), + ("SELECT decimal_col FROM #pytest_fetchval_types", decimal.Decimal('89.12'), decimal.Decimal), + ("SELECT varchar_col FROM #pytest_fetchval_types", 'ASCII text', str), + ("SELECT nvarchar_col FROM #pytest_fetchval_types", 'Unicode text', str), + ("SELECT bit_col FROM #pytest_fetchval_types", 1, int), + ("SELECT datetime_col FROM #pytest_fetchval_types", datetime(2024, 5, 20, 12, 34, 56), datetime), + ("SELECT date_col FROM #pytest_fetchval_types", date(2024, 5, 20), date), + ("SELECT time_col FROM #pytest_fetchval_types", time(12, 34, 56), time), + ] + + for query, expected_value, expected_type in test_cases: + cursor.execute(query) + result = cursor.fetchval() + assert isinstance(result, expected_type), f"fetchval should return {expected_type.__name__} for {query}" + if isinstance(expected_value, float): + assert abs(result - expected_value) < 0.01, f"Float values should be approximately equal for {query}" + else: + assert result == expected_value, f"fetchval should return {expected_value} for {query}" + + except Exception as e: + pytest.fail(f"fetchval data types test failed: {e}") + finally: + try: + cursor.execute("DROP TABLE #pytest_fetchval_types") + db_connection.commit() + except: + pass + +def test_fetchval_null_values(cursor, db_connection): + """Test fetchval with NULL values""" + try: + # Test explicit NULL + cursor.execute("SELECT NULL") + result = cursor.fetchval() + assert result is None, "fetchval should return None for NULL value" + + # Test NULL from table + drop_table_if_exists(cursor, "#pytest_fetchval_null") + cursor.execute("CREATE TABLE #pytest_fetchval_null (col VARCHAR(50))") + cursor.execute("INSERT INTO #pytest_fetchval_null VALUES (NULL)") + db_connection.commit() + + cursor.execute("SELECT col FROM #pytest_fetchval_null") + result = cursor.fetchval() + assert result is None, "fetchval should return None for NULL column value" + + except Exception as e: + pytest.fail(f"fetchval NULL values test failed: {e}") + finally: + try: + cursor.execute("DROP TABLE #pytest_fetchval_null") + db_connection.commit() + except: + pass + +def test_fetchval_no_results(cursor, db_connection): + """Test fetchval when query returns no rows""" + try: + # Create empty table + drop_table_if_exists(cursor, "#pytest_fetchval_empty") + cursor.execute("CREATE TABLE #pytest_fetchval_empty (col INTEGER)") + db_connection.commit() + + # Query empty table + cursor.execute("SELECT col FROM #pytest_fetchval_empty") + result = cursor.fetchval() + assert result is None, "fetchval should return None when no rows are returned" + + # Query with WHERE clause that matches nothing + cursor.execute("SELECT col FROM #pytest_fetchval_empty WHERE col = 999") + result = cursor.fetchval() + assert result is None, "fetchval should return None when WHERE clause matches no rows" + + except Exception as e: + pytest.fail(f"fetchval no results test failed: {e}") + finally: + try: + cursor.execute("DROP TABLE #pytest_fetchval_empty") + db_connection.commit() + except: + pass + +def test_fetchval_multiple_columns(cursor, db_connection): + """Test fetchval with queries that return multiple columns (should return first column)""" + try: + drop_table_if_exists(cursor, "#pytest_fetchval_multi") + cursor.execute("CREATE TABLE #pytest_fetchval_multi (col1 INTEGER, col2 VARCHAR(50), col3 FLOAT)") + cursor.execute("INSERT INTO #pytest_fetchval_multi VALUES (100, 'second column', 3.14)") + db_connection.commit() + + # Query multiple columns - should return first column + cursor.execute("SELECT col1, col2, col3 FROM #pytest_fetchval_multi") + result = cursor.fetchval() + assert result == 100, "fetchval should return first column value when multiple columns are selected" + + # Test with different order + cursor.execute("SELECT col2, col1, col3 FROM #pytest_fetchval_multi") + result = cursor.fetchval() + assert result == 'second column', "fetchval should return first column value regardless of column order" + + except Exception as e: + pytest.fail(f"fetchval multiple columns test failed: {e}") + finally: + try: + cursor.execute("DROP TABLE #pytest_fetchval_multi") + db_connection.commit() + except: + pass + +def test_fetchval_multiple_rows(cursor, db_connection): + """Test fetchval with queries that return multiple rows (should return first row, first column)""" + try: + drop_table_if_exists(cursor, "#pytest_fetchval_rows") + cursor.execute("CREATE TABLE #pytest_fetchval_rows (col INTEGER)") + cursor.execute("INSERT INTO #pytest_fetchval_rows VALUES (10)") + cursor.execute("INSERT INTO #pytest_fetchval_rows VALUES (20)") + cursor.execute("INSERT INTO #pytest_fetchval_rows VALUES (30)") + db_connection.commit() + + # Query multiple rows - should return first row's first column + cursor.execute("SELECT col FROM #pytest_fetchval_rows ORDER BY col") + result = cursor.fetchval() + assert result == 10, "fetchval should return first row's first column value" + + # Verify cursor position advanced by one row + next_row = cursor.fetchone() + assert next_row[0] == 20, "Cursor should advance by one row after fetchval" + + except Exception as e: + pytest.fail(f"fetchval multiple rows test failed: {e}") + finally: + try: + cursor.execute("DROP TABLE #pytest_fetchval_rows") + db_connection.commit() + except: + pass + +def test_fetchval_method_chaining(cursor, db_connection): + """Test fetchval with method chaining from execute""" + try: + # Test method chaining - execute returns cursor, so we can chain fetchval + result = cursor.execute("SELECT 42").fetchval() + assert result == 42, "fetchval should work with method chaining from execute" + + # Test with parameterized query + result = cursor.execute("SELECT ?", 123).fetchval() + assert result == 123, "fetchval should work with method chaining on parameterized queries" + + except Exception as e: + pytest.fail(f"fetchval method chaining test failed: {e}") + +def test_fetchval_closed_cursor(db_connection): + """Test fetchval on closed cursor should raise exception""" + try: + cursor = db_connection.cursor() + cursor.close() + + with pytest.raises(Exception) as exc_info: + cursor.fetchval() + + assert "closed" in str(exc_info.value).lower(), "fetchval on closed cursor should raise exception mentioning cursor is closed" + + except Exception as e: + if "closed" not in str(e).lower(): + pytest.fail(f"fetchval closed cursor test failed: {e}") + +def test_fetchval_rownumber_tracking(cursor, db_connection): + """Test that fetchval properly updates rownumber tracking""" + try: + drop_table_if_exists(cursor, "#pytest_fetchval_rownumber") + cursor.execute("CREATE TABLE #pytest_fetchval_rownumber (col INTEGER)") + cursor.execute("INSERT INTO #pytest_fetchval_rownumber VALUES (1)") + cursor.execute("INSERT INTO #pytest_fetchval_rownumber VALUES (2)") + db_connection.commit() + + # Execute query to set up result set + cursor.execute("SELECT col FROM #pytest_fetchval_rownumber ORDER BY col") + + # Check initial rownumber + initial_rownumber = cursor.rownumber + + # Use fetchval + result = cursor.fetchval() + assert result == 1, "fetchval should return first row value" + + # Check that rownumber was incremented + assert cursor.rownumber == initial_rownumber + 1, "fetchval should increment rownumber" + + # Verify next fetch gets the second row + next_row = cursor.fetchone() + assert next_row[0] == 2, "Next fetchone should return second row after fetchval" + + except Exception as e: + pytest.fail(f"fetchval rownumber tracking test failed: {e}") + finally: + try: + cursor.execute("DROP TABLE #pytest_fetchval_rownumber") + db_connection.commit() + except: + pass + +def test_fetchval_aggregate_functions(cursor, db_connection): + """Test fetchval with common aggregate functions""" + try: + drop_table_if_exists(cursor, "#pytest_fetchval_agg") + cursor.execute("CREATE TABLE #pytest_fetchval_agg (value INTEGER)") + cursor.execute("INSERT INTO #pytest_fetchval_agg VALUES (10), (20), (30), (40), (50)") + db_connection.commit() + + # Test various aggregate functions + test_cases = [ + ("SELECT COUNT(*) FROM #pytest_fetchval_agg", 5), + ("SELECT SUM(value) FROM #pytest_fetchval_agg", 150), + ("SELECT AVG(value) FROM #pytest_fetchval_agg", 30), + ("SELECT MIN(value) FROM #pytest_fetchval_agg", 10), + ("SELECT MAX(value) FROM #pytest_fetchval_agg", 50), + ] + + for query, expected in test_cases: + cursor.execute(query) + result = cursor.fetchval() + if isinstance(expected, float): + assert abs(result - expected) < 0.01, f"Aggregate function result should match for {query}" + else: + assert result == expected, f"Aggregate function result should be {expected} for {query}" + + except Exception as e: + pytest.fail(f"fetchval aggregate functions test failed: {e}") + finally: + try: + cursor.execute("DROP TABLE #pytest_fetchval_agg") + db_connection.commit() + except: + pass + +def test_fetchval_empty_result_set_edge_cases(cursor, db_connection): + """Test fetchval edge cases with empty result sets""" + try: + # Test with conditional that never matches + cursor.execute("SELECT 1 WHERE 1 = 0") + result = cursor.fetchval() + assert result is None, "fetchval should return None for impossible condition" + + # Test with CASE statement that could return NULL + cursor.execute("SELECT CASE WHEN 1 = 0 THEN 'never' ELSE NULL END") + result = cursor.fetchval() + assert result is None, "fetchval should return None for CASE returning NULL" + + # Test with subquery returning no rows + cursor.execute("SELECT (SELECT COUNT(*) FROM sys.databases WHERE name = 'nonexistent_db_name_12345')") + result = cursor.fetchval() + assert result == 0, "fetchval should return 0 for COUNT with no matches" + + except Exception as e: + pytest.fail(f"fetchval empty result set edge cases test failed: {e}") + +def test_fetchval_error_scenarios(cursor, db_connection): + """Test fetchval error scenarios and recovery""" + try: + # Test fetchval after successful execute + cursor.execute("SELECT 'test'") + result = cursor.fetchval() + assert result == 'test', "fetchval should work after successful execute" + + # Test fetchval on cursor without prior execute should raise exception + cursor2 = db_connection.cursor() + try: + result = cursor2.fetchval() + # If this doesn't raise an exception, that's also acceptable behavior + # depending on the implementation + except Exception: + # Expected - cursor might not have a result set + pass + finally: + cursor2.close() + + except Exception as e: + pytest.fail(f"fetchval error scenarios test failed: {e}") + +def test_fetchval_performance_common_patterns(cursor, db_connection): + """Test fetchval with common performance-related patterns""" + try: + drop_table_if_exists(cursor, "#pytest_fetchval_perf") + cursor.execute("CREATE TABLE #pytest_fetchval_perf (id INTEGER IDENTITY(1,1), data VARCHAR(100))") + + # Insert some test data + for i in range(10): + cursor.execute("INSERT INTO #pytest_fetchval_perf (data) VALUES (?)", f"data_{i}") + db_connection.commit() + + # Test EXISTS pattern + cursor.execute("SELECT CASE WHEN EXISTS(SELECT 1 FROM #pytest_fetchval_perf WHERE data = 'data_5') THEN 1 ELSE 0 END") + exists_result = cursor.fetchval() + assert exists_result == 1, "EXISTS pattern should return 1 when record exists" + + # Test TOP 1 pattern + cursor.execute("SELECT TOP 1 id FROM #pytest_fetchval_perf ORDER BY id") + top_result = cursor.fetchval() + assert top_result == 1, "TOP 1 pattern should return first record" + + # Test scalar subquery pattern + cursor.execute("SELECT (SELECT COUNT(*) FROM #pytest_fetchval_perf)") + count_result = cursor.fetchval() + assert count_result == 10, "Scalar subquery should return correct count" + + except Exception as e: + pytest.fail(f"fetchval performance patterns test failed: {e}") + finally: + try: + cursor.execute("DROP TABLE #pytest_fetchval_perf") + db_connection.commit() + except: + pass + def test_close(db_connection): """Test closing the cursor""" try: From c3352b2cb1613ce61ddc88de677ec0f9f8224003 Mon Sep 17 00:00:00 2001 From: Jahnvi Thakkar Date: Mon, 25 Aug 2025 15:28:03 +0530 Subject: [PATCH 2/3] Resolving comments --- mssql_python/cursor.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index 217e04755..f975bb389 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -999,16 +999,7 @@ def fetchval(self): # Fetch the first row row = self.fetchone() - # If no row is available, return None - if row is None: - return None - - # If the row has no columns, return None (shouldn't happen in normal cases) - if len(row) == 0: - return None - - # Return the first column value (could be None if the column value is NULL) - return row[0] + return None if row is None else row[0] def __del__(self): """ From 1e614a69cd0f67f1873678f70d0c6e3b22a51175 Mon Sep 17 00:00:00 2001 From: Jahnvi Thakkar <61936179+jahnvi480@users.noreply.github.com> Date: Wed, 27 Aug 2025 17:05:37 +0530 Subject: [PATCH 3/3] FEAT: Adding commit and rollback in cursor (#179) ### Work Item / Issue Reference > [AB#34922](https://sqlclientdrivers.visualstudio.com/c6d89619-62de-46a0-8b46-70b92a84d85e/_workitems/edit/34922) > [AB#34923](https://sqlclientdrivers.visualstudio.com/c6d89619-62de-46a0-8b46-70b92a84d85e/_workitems/edit/34923) ------------------------------------------------------------------- ### Summary This pull request adds two convenience methods to the `Cursor` class in `mssql_python/cursor.py`, making it easier to manage transactions directly from the cursor object without needing to access the underlying connection. This improves usability for developers working with the cursor API. Transaction management enhancements: * Added a `commit` method to the `Cursor` class, allowing users to commit transactions directly from the cursor. This delegates to the underlying connection's `commit` method and provides error handling if the cursor is closed. * Added a `rollback` method to the `Cursor` class, allowing users to roll back transactions directly from the cursor. This delegates to the underlying connection's `rollback` method and provides error handling if the cursor is closed. --------- Co-authored-by: Jahnvi Thakkar --- mssql_python/constants.py | 9 +- mssql_python/cursor.py | 403 ++++- mssql_python/pybind/ddbc_bindings.cpp | 205 ++- mssql_python/pybind/ddbc_bindings.h | 14 +- mssql_python/row.py | 30 +- tests/test_004_cursor.py | 1950 ++++++++++++++++++++++++- 6 files changed, 2570 insertions(+), 41 deletions(-) diff --git a/mssql_python/constants.py b/mssql_python/constants.py index 81e60d37e..20c8f6636 100644 --- a/mssql_python/constants.py +++ b/mssql_python/constants.py @@ -97,7 +97,6 @@ class ConstantsDDBC(Enum): SQL_ATTR_ROW_ARRAY_SIZE = 27 SQL_ATTR_ROWS_FETCHED_PTR = 26 SQL_ATTR_ROW_STATUS_PTR = 25 - SQL_FETCH_NEXT = 1 SQL_ROW_SUCCESS = 0 SQL_ROW_SUCCESS_WITH_INFO = 1 SQL_ROW_NOROW = 100 @@ -117,6 +116,14 @@ class ConstantsDDBC(Enum): SQL_NULLABLE = 1 SQL_MAX_NUMERIC_LEN = 16 + SQL_FETCH_NEXT = 1 + SQL_FETCH_FIRST = 2 + SQL_FETCH_LAST = 3 + SQL_FETCH_PRIOR = 4 + SQL_FETCH_ABSOLUTE = 5 + SQL_FETCH_RELATIVE = 6 + SQL_FETCH_BOOKMARK = 8 + class AuthType(Enum): """Constants for authentication types""" INTERACTIVE = "activedirectoryinteractive" diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index f975bb389..6e1bfcee2 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -8,7 +8,6 @@ - Do not use a cursor after it is closed, or after its parent connection is closed. - Use close() to release resources held by the cursor as soon as it is no longer needed. """ -import ctypes import decimal import uuid import datetime @@ -16,7 +15,7 @@ from mssql_python.constants import ConstantsDDBC as ddbc_sql_const from mssql_python.helpers import check_error, log from mssql_python import ddbc_bindings -from mssql_python.exceptions import InterfaceError +from mssql_python.exceptions import InterfaceError, NotSupportedError, ProgrammingError from .row import Row @@ -77,8 +76,12 @@ def __init__(self, connection) -> None: # Therefore, it must be a list with exactly one bool element. # rownumber attribute - self._rownumber = -1 # Track the current row index in the result set + self._rownumber = -1 # DB-API extension: last returned row index, -1 before first + self._next_row_index = 0 # internal: index of the next row the driver will return (0-based) self._has_result_set = False # Track if we have an active result set + self._skip_increment_for_next_fetch = False # Track if we need to skip incrementing the row index + + self.messages = [] # Store diagnostic messages def _is_unicode_string(self, param): """ @@ -452,6 +455,9 @@ def close(self) -> None: if self.closed: raise Exception("Cursor is already closed.") + # Clear messages per DBAPI + self.messages = [] + if self.hstmt: self.hstmt.free() self.hstmt = None @@ -594,18 +600,21 @@ def connection(self): def _reset_rownumber(self): """Reset the rownumber tracking when starting a new result set.""" self._rownumber = -1 + self._next_row_index = 0 self._has_result_set = True + self._skip_increment_for_next_fetch = False def _increment_rownumber(self): """ - Increment the rownumber by 1. - - This should be called after each fetch operation to keep track of the current row index. + Called after a successful fetch from the driver. Keep both counters consistent. """ if self._has_result_set: - self._rownumber += 1 + # driver returned one row, so the next row index increments by 1 + self._next_row_index += 1 + # rownumber is last returned row index + self._rownumber = self._next_row_index - 1 else: - raise InterfaceError("Cannot increment rownumber: no active result set.") + raise InterfaceError("Cannot increment rownumber: no active result set.", "No active result set.") # Will be used when we add support for scrollable cursors def _decrement_rownumber(self): @@ -620,8 +629,8 @@ def _decrement_rownumber(self): else: self._rownumber = -1 else: - raise InterfaceError("Cannot decrement rownumber: no active result set.") - + raise InterfaceError("Cannot decrement rownumber: no active result set.", "No active result set.") + def _clear_rownumber(self): """ Clear the rownumber tracking. @@ -630,6 +639,7 @@ def _clear_rownumber(self): """ self._rownumber = -1 self._has_result_set = False + self._skip_increment_for_next_fetch = False def __iter__(self): """ @@ -693,6 +703,9 @@ def execute( if reset_cursor: self._reset_cursor() + # Clear any previous messages + self.messages = [] + param_info = ddbc_bindings.ParamInfo parameters_type = [] @@ -740,7 +753,14 @@ def execute( self.is_stmt_prepared, use_prepare, ) + + # Check for errors but don't raise exceptions for info/warning messages check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret) + + # Capture any diagnostic messages (SQL_SUCCESS_WITH_INFO, etc.) + if self.hstmt: + self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt)) + self.last_executed_stmt = operation # Update rowcount after execution @@ -752,8 +772,10 @@ def execute( # Reset rownumber for new result set (only for SELECT statements) if self.description: # If we have column descriptions, it's likely a SELECT + self.rowcount = -1 self._reset_rownumber() else: + self.rowcount = ddbc_bindings.DDBCSQLRowCount(self.hstmt) self._clear_rownumber() # Return self for method chaining @@ -820,7 +842,10 @@ def executemany(self, operation: str, seq_of_parameters: list) -> None: """ self._check_closed() self._reset_cursor() - + + # Clear any previous messages + self.messages = [] + if not seq_of_parameters: self.rowcount = 0 return @@ -852,13 +877,19 @@ def executemany(self, operation: str, seq_of_parameters: list) -> None: ) check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret) + # Capture any diagnostic messages after execution + if self.hstmt: + self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt)) + self.rowcount = ddbc_bindings.DDBCSQLRowCount(self.hstmt) self.last_executed_stmt = operation self._initialize_description() if self.description: + self.rowcount = -1 self._reset_rownumber() else: + self.rowcount = ddbc_bindings.DDBCSQLRowCount(self.hstmt) self._clear_rownumber() def fetchone(self) -> Union[None, Row]: @@ -875,14 +906,22 @@ def fetchone(self) -> Union[None, Row]: try: ret = ddbc_bindings.DDBCSQLFetchOne(self.hstmt, row_data) + if self.hstmt: + self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt)) + if ret == ddbc_sql_const.SQL_NO_DATA.value: return None - # Only increment rownumber for successful fetch with data - self._increment_rownumber() + # Update internal position after successful fetch + if self._skip_increment_for_next_fetch: + self._skip_increment_for_next_fetch = False + self._next_row_index += 1 + else: + self._increment_rownumber() - # Create and return a Row object - return Row(row_data, self.description) + # Create and return a Row object, passing column name map if available + column_map = getattr(self, '_column_name_map', None) + return Row(row_data, self.description, column_map) except Exception as e: # On error, don't increment rownumber - rethrow the error raise e @@ -898,6 +937,8 @@ def fetchmany(self, size: int = None) -> List[Row]: List of Row objects. """ self._check_closed() # Check if the cursor is closed + if not self._has_result_set and self.description: + self._reset_rownumber() if size is None: size = self.arraysize @@ -909,14 +950,20 @@ def fetchmany(self, size: int = None) -> List[Row]: rows_data = [] try: ret = ddbc_bindings.DDBCSQLFetchMany(self.hstmt, rows_data, size) + + if self.hstmt: + self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt)) + # Update rownumber for the number of rows actually fetched if rows_data and self._has_result_set: - for _ in rows_data: - self._increment_rownumber() + # advance counters by number of rows actually returned + self._next_row_index += len(rows_data) + self._rownumber = self._next_row_index - 1 # Convert raw data to Row objects - return [Row(row_data, self.description) for row_data in rows_data] + column_map = getattr(self, '_column_name_map', None) + return [Row(row_data, self.description, column_map) for row_data in rows_data] except Exception as e: # On error, don't increment rownumber - rethrow the error raise e @@ -929,19 +976,26 @@ def fetchall(self) -> List[Row]: List of Row objects. """ self._check_closed() # Check if the cursor is closed + if not self._has_result_set and self.description: + self._reset_rownumber() # Fetch raw data rows_data = [] try: ret = ddbc_bindings.DDBCSQLFetchAll(self.hstmt, rows_data) + + if self.hstmt: + self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt)) + # Update rownumber for the number of rows actually fetched if rows_data and self._has_result_set: - for _ in rows_data: - self._increment_rownumber() + self._next_row_index += len(rows_data) + self._rownumber = self._next_row_index - 1 # Convert raw data to Row objects - return [Row(row_data, self.description) for row_data in rows_data] + column_map = getattr(self, '_column_name_map', None) + return [Row(row_data, self.description, column_map) for row_data in rows_data] except Exception as e: # On error, don't increment rownumber - rethrow the error raise e @@ -958,6 +1012,9 @@ def nextset(self) -> Union[bool, None]: """ self._check_closed() # Check if the cursor is closed + # Clear messages per DBAPI + self.messages = [] + # Skip to the next result set ret = ddbc_bindings.DDBCSQLMoreResults(self.hstmt) check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret) @@ -996,11 +1053,74 @@ def fetchval(self): """ self._check_closed() # Check if the cursor is closed + # Check if this is a result-producing statement + if not self.description: + # Non-result-set statement (INSERT, UPDATE, DELETE, etc.) + return None + # Fetch the first row row = self.fetchone() return None if row is None else row[0] + def commit(self): + """ + Commit all SQL statements executed on the connection that created this cursor. + + This is a convenience method that calls commit() on the underlying connection. + It affects all cursors created by the same connection since the last commit/rollback. + + The benefit is that many uses can now just use the cursor and not have to track + the connection object. + + Raises: + Exception: If the cursor is closed or if the commit operation fails. + + Example: + >>> cursor.execute("INSERT INTO users (name) VALUES (?)", "John") + >>> cursor.commit() # Commits the INSERT + + Note: + This is equivalent to calling connection.commit() but provides convenience + for code that only has access to the cursor object. + """ + self._check_closed() # Check if the cursor is closed + + # Clear messages per DBAPI + self.messages = [] + + # Delegate to the connection's commit method + self._connection.commit() + + def rollback(self): + """ + Roll back all SQL statements executed on the connection that created this cursor. + + This is a convenience method that calls rollback() on the underlying connection. + It affects all cursors created by the same connection since the last commit/rollback. + + The benefit is that many uses can now just use the cursor and not have to track + the connection object. + + Raises: + Exception: If the cursor is closed or if the rollback operation fails. + + Example: + >>> cursor.execute("INSERT INTO users (name) VALUES (?)", "John") + >>> cursor.rollback() # Rolls back the INSERT + + Note: + This is equivalent to calling connection.rollback() but provides convenience + for code that only has access to the cursor object. + """ + self._check_closed() # Check if the cursor is closed + + # Clear messages per DBAPI + self.messages = [] + + # Delegate to the connection's rollback method + self._connection.rollback() + def __del__(self): """ Destructor to ensure the cursor is closed when it is no longer needed. @@ -1012,4 +1132,243 @@ def __del__(self): self.close() except Exception as e: # Don't raise an exception in __del__, just log it - log('error', "Error during cursor cleanup in __del__: %s", e) \ No newline at end of file + log('error', "Error during cursor cleanup in __del__: %s", e) + + def scroll(self, value: int, mode: str = 'relative') -> None: + """ + Scroll using SQLFetchScroll only, matching test semantics: + - relative(N>0): consume N rows; rownumber = previous + N; next fetch returns the following row. + - absolute(-1): before first (rownumber = -1), no data consumed. + - absolute(0): position so next fetch returns first row; rownumber stays 0 even after that fetch. + - absolute(k>0): next fetch returns row index k (0-based); rownumber == k after scroll. + """ + self._check_closed() + + # Clear messages per DBAPI + self.messages = [] + + if mode not in ('relative', 'absolute'): + raise ProgrammingError("Invalid scroll mode", + f"mode must be 'relative' or 'absolute', got '{mode}'") + if not self._has_result_set: + raise ProgrammingError("No active result set", + "Cannot scroll: no result set available. Execute a query first.") + if not isinstance(value, int): + raise ProgrammingError("Invalid scroll value type", + f"scroll value must be an integer, got {type(value).__name__}") + + # Relative backward not supported + if mode == 'relative' and value < 0: + raise NotSupportedError("Backward scrolling not supported", + f"Cannot move backward by {value} rows on a forward-only cursor") + + row_data: list = [] + + # Absolute special cases + if mode == 'absolute': + if value == -1: + # Before first + ddbc_bindings.DDBCSQLFetchScroll(self.hstmt, + ddbc_sql_const.SQL_FETCH_ABSOLUTE.value, + 0, row_data) + self._rownumber = -1 + self._next_row_index = 0 + return + if value == 0: + # Before first, but tests want rownumber==0 pre and post the next fetch + ddbc_bindings.DDBCSQLFetchScroll(self.hstmt, + ddbc_sql_const.SQL_FETCH_ABSOLUTE.value, + 0, row_data) + self._rownumber = 0 + self._next_row_index = 0 + self._skip_increment_for_next_fetch = True + return + + try: + if mode == 'relative': + if value == 0: + return + ret = ddbc_bindings.DDBCSQLFetchScroll(self.hstmt, + ddbc_sql_const.SQL_FETCH_RELATIVE.value, + value, row_data) + if ret == ddbc_sql_const.SQL_NO_DATA.value: + raise IndexError("Cannot scroll to specified position: end of result set reached") + # Consume N rows; last-returned index advances by N + self._rownumber = self._rownumber + value + self._next_row_index = self._rownumber + 1 + return + + # absolute(k>0): map Python k (0-based next row) to ODBC ABSOLUTE k (1-based), + # intentionally passing k so ODBC fetches row #k (1-based), i.e., 0-based (k-1), + # leaving the NEXT fetch to return 0-based index k. + ret = ddbc_bindings.DDBCSQLFetchScroll(self.hstmt, + ddbc_sql_const.SQL_FETCH_ABSOLUTE.value, + value, row_data) + if ret == ddbc_sql_const.SQL_NO_DATA.value: + raise IndexError(f"Cannot scroll to position {value}: end of result set reached") + + # Tests expect rownumber == value after absolute(value) + # Next fetch should return row index 'value' + self._rownumber = value + self._next_row_index = value + + except Exception as e: + if isinstance(e, (IndexError, NotSupportedError)): + raise + raise IndexError(f"Scroll operation failed: {e}") from e + + def skip(self, count: int) -> None: + """ + Skip the next count records in the query result set. + + Args: + count: Number of records to skip. + + Raises: + IndexError: If attempting to skip past the end of the result set. + ProgrammingError: If count is not an integer. + NotSupportedError: If attempting to skip backwards. + """ + from mssql_python.exceptions import ProgrammingError, NotSupportedError + + self._check_closed() + + # Clear messages + self.messages = [] + + # Simply delegate to the scroll method with 'relative' mode + self.scroll(count, 'relative') + + def _execute_tables(self, stmt_handle, catalog_name=None, schema_name=None, table_name=None, + table_type=None, search_escape=None): + """ + Execute SQLTables ODBC function to retrieve table metadata. + + Args: + stmt_handle: ODBC statement handle + catalog_name: The catalog name pattern + schema_name: The schema name pattern + table_name: The table name pattern + table_type: The table type filter + search_escape: The escape character for pattern matching + """ + # Convert None values to empty strings for ODBC + catalog = "" if catalog_name is None else catalog_name + schema = "" if schema_name is None else schema_name + table = "" if table_name is None else table_name + types = "" if table_type is None else table_type + + # Call the ODBC SQLTables function + retcode = ddbc_bindings.DDBCSQLTables( + stmt_handle, + catalog, + schema, + table, + types + ) + + # Check return code and handle errors + check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, stmt_handle, retcode) + + # Capture any diagnostic messages + if stmt_handle: + self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(stmt_handle)) + + def tables(self, table=None, catalog=None, schema=None, tableType=None): + """ + Returns information about tables in the database that match the given criteria using + the SQLTables ODBC function. + + Args: + table (str, optional): The table name pattern. Default is None (all tables). + catalog (str, optional): The catalog name. Default is None. + schema (str, optional): The schema name pattern. Default is None. + tableType (str or list, optional): The table type filter. Default is None. + Example: "TABLE" or ["TABLE", "VIEW"] + + Returns: + list: A list of Row objects containing table information with these columns: + - table_cat: Catalog name + - table_schem: Schema name + - table_name: Table name + - table_type: Table type (e.g., "TABLE", "VIEW") + - remarks: Comments about the table + + Notes: + This method only processes the standard five columns as defined in the ODBC + specification. Any additional columns that might be returned by specific ODBC + drivers are not included in the result set. + + Example: + # Get all tables in the database + tables = cursor.tables() + + # Get all tables in schema 'dbo' + tables = cursor.tables(schema='dbo') + + # Get table named 'Customers' + tables = cursor.tables(table='Customers') + + # Get all views + tables = cursor.tables(tableType='VIEW') + """ + self._check_closed() + + # Clear messages + self.messages = [] + + # Always reset the cursor first to ensure clean state + self._reset_cursor() + + # Format table_type parameter - SQLTables expects comma-separated string + table_type_str = None + if tableType is not None: + if isinstance(tableType, (list, tuple)): + table_type_str = ",".join(tableType) + else: + table_type_str = str(tableType) + + # Call SQLTables via the helper method + self._execute_tables( + self.hstmt, + catalog_name=catalog, + schema_name=schema, + table_name=table, + table_type=table_type_str + ) + + # Initialize description from column metadata + column_metadata = [] + try: + ddbc_bindings.DDBCSQLDescribeCol(self.hstmt, column_metadata) + self._initialize_description(column_metadata) + except Exception: + # If describe fails, create a manual description for the standard columns + column_types = [str, str, str, str, str] + self.description = [ + ("table_cat", column_types[0], None, 128, 128, 0, True), + ("table_schem", column_types[1], None, 128, 128, 0, True), + ("table_name", column_types[2], None, 128, 128, 0, False), + ("table_type", column_types[3], None, 128, 128, 0, False), + ("remarks", column_types[4], None, 254, 254, 0, True) + ] + + # Define column names in ODBC standard order + column_names = [ + "table_cat", "table_schem", "table_name", "table_type", "remarks" + ] + + # Fetch all rows + rows_data = [] + ddbc_bindings.DDBCSQLFetchAll(self.hstmt, rows_data) + + # Create a column map for attribute access + column_map = {name: i for i, name in enumerate(column_names)} + + # Create Row objects with the column map + result_rows = [] + for row_data in rows_data: + row = Row(row_data, self.description, column_map) + result_rows.append(row) + + return result_rows \ No newline at end of file diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 1b37b8f0f..b5cabd4bf 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -134,6 +134,7 @@ SQLFreeStmtFunc SQLFreeStmt_ptr = nullptr; // Diagnostic APIs SQLGetDiagRecFunc SQLGetDiagRec_ptr = nullptr; +SQLTablesFunc SQLTables_ptr = nullptr; namespace { @@ -786,6 +787,7 @@ DriverHandle LoadDriverOrThrowException() { SQLFreeStmt_ptr = GetFunctionPointer(handle, "SQLFreeStmt"); SQLGetDiagRec_ptr = GetFunctionPointer(handle, "SQLGetDiagRecW"); + SQLTables_ptr = GetFunctionPointer(handle, "SQLTablesW"); bool success = SQLAllocHandle_ptr && SQLSetEnvAttr_ptr && SQLSetConnectAttr_ptr && @@ -796,7 +798,7 @@ DriverHandle LoadDriverOrThrowException() { SQLGetData_ptr && SQLNumResultCols_ptr && SQLBindCol_ptr && SQLDescribeCol_ptr && SQLMoreResults_ptr && SQLColAttribute_ptr && SQLEndTran_ptr && SQLDisconnect_ptr && SQLFreeHandle_ptr && - SQLFreeStmt_ptr && SQLGetDiagRec_ptr; + SQLFreeStmt_ptr && SQLGetDiagRec_ptr && SQLTables_ptr; if (!success) { ThrowStdException("Failed to load required function pointers from driver."); @@ -901,6 +903,65 @@ ErrorInfo SQLCheckError_Wrap(SQLSMALLINT handleType, SqlHandlePtr handle, SQLRET return errorInfo; } +py::list SQLGetAllDiagRecords(SqlHandlePtr handle) { + LOG("Retrieving all diagnostic records"); + if (!SQLGetDiagRec_ptr) { + LOG("Function pointer not initialized. Loading the driver."); + DriverLoader::getInstance().loadDriver(); + } + + py::list records; + SQLHANDLE rawHandle = handle->get(); + SQLSMALLINT handleType = handle->type(); + + // Iterate through all available diagnostic records + for (SQLSMALLINT recNumber = 1; ; recNumber++) { + SQLWCHAR sqlState[6] = {0}; + SQLWCHAR message[SQL_MAX_MESSAGE_LENGTH] = {0}; + SQLINTEGER nativeError = 0; + SQLSMALLINT messageLen = 0; + + SQLRETURN diagReturn = SQLGetDiagRec_ptr( + handleType, rawHandle, recNumber, sqlState, &nativeError, + message, SQL_MAX_MESSAGE_LENGTH, &messageLen); + + if (diagReturn == SQL_NO_DATA || !SQL_SUCCEEDED(diagReturn)) + break; + +#if defined(_WIN32) + // On Windows, create a formatted UTF-8 string for state+error + char stateWithError[50]; + sprintf(stateWithError, "[%ls] (%d)", sqlState, nativeError); + + // Convert wide string message to UTF-8 + int msgSize = WideCharToMultiByte(CP_UTF8, 0, message, -1, NULL, 0, NULL, NULL); + std::vector msgBuffer(msgSize); + WideCharToMultiByte(CP_UTF8, 0, message, -1, msgBuffer.data(), msgSize, NULL, NULL); + + // Create the tuple with converted strings + records.append(py::make_tuple( + py::str(stateWithError), + py::str(msgBuffer.data()) + )); +#else + // On Unix, use the SQLWCHARToWString utility and then convert to UTF-8 + std::string stateStr = WideToUTF8(SQLWCHARToWString(sqlState)); + std::string msgStr = WideToUTF8(SQLWCHARToWString(message, messageLen)); + + // Format the state string + std::string stateWithError = "[" + stateStr + "] (" + std::to_string(nativeError) + ")"; + + // Create the tuple with converted strings + records.append(py::make_tuple( + py::str(stateWithError), + py::str(msgStr) + )); +#endif + } + + return records; +} + // Wrap SQLExecDirect SQLRETURN SQLExecDirect_wrap(SqlHandlePtr StatementHandle, const std::wstring& Query) { LOG("Execute SQL query directly - {}", Query.c_str()); @@ -909,6 +970,18 @@ SQLRETURN SQLExecDirect_wrap(SqlHandlePtr StatementHandle, const std::wstring& Q DriverLoader::getInstance().loadDriver(); // Load the driver } + // Ensure statement is scrollable BEFORE executing + if (SQLSetStmtAttr_ptr && StatementHandle && StatementHandle->get()) { + SQLSetStmtAttr_ptr(StatementHandle->get(), + SQL_ATTR_CURSOR_TYPE, + (SQLPOINTER)SQL_CURSOR_STATIC, + 0); + SQLSetStmtAttr_ptr(StatementHandle->get(), + SQL_ATTR_CONCURRENCY, + (SQLPOINTER)SQL_CONCUR_READ_ONLY, + 0); + } + SQLWCHAR* queryPtr; #if defined(__APPLE__) || defined(__linux__) std::vector queryBuffer = WStringToSQLWCHAR(Query); @@ -923,6 +996,91 @@ SQLRETURN SQLExecDirect_wrap(SqlHandlePtr StatementHandle, const std::wstring& Q return ret; } +// Wrapper for SQLTables +SQLRETURN SQLTables_wrap(SqlHandlePtr StatementHandle, + const std::wstring& catalog, + const std::wstring& schema, + const std::wstring& table, + const std::wstring& tableType) { + + if (!SQLTables_ptr) { + LOG("Function pointer not initialized. Loading the driver."); + DriverLoader::getInstance().loadDriver(); + } + + SQLWCHAR* catalogPtr = nullptr; + SQLWCHAR* schemaPtr = nullptr; + SQLWCHAR* tablePtr = nullptr; + SQLWCHAR* tableTypePtr = nullptr; + SQLSMALLINT catalogLen = 0; + SQLSMALLINT schemaLen = 0; + SQLSMALLINT tableLen = 0; + SQLSMALLINT tableTypeLen = 0; + + std::vector catalogBuffer; + std::vector schemaBuffer; + std::vector tableBuffer; + std::vector tableTypeBuffer; + +#if defined(__APPLE__) || defined(__linux__) + // On Unix platforms, convert wstring to SQLWCHAR array + if (!catalog.empty()) { + catalogBuffer = WStringToSQLWCHAR(catalog); + catalogPtr = catalogBuffer.data(); + catalogLen = SQL_NTS; + } + if (!schema.empty()) { + schemaBuffer = WStringToSQLWCHAR(schema); + schemaPtr = schemaBuffer.data(); + schemaLen = SQL_NTS; + } + if (!table.empty()) { + tableBuffer = WStringToSQLWCHAR(table); + tablePtr = tableBuffer.data(); + tableLen = SQL_NTS; + } + if (!tableType.empty()) { + tableTypeBuffer = WStringToSQLWCHAR(tableType); + tableTypePtr = tableTypeBuffer.data(); + tableTypeLen = SQL_NTS; + } +#else + // On Windows, direct assignment works + if (!catalog.empty()) { + catalogPtr = const_cast(catalog.c_str()); + catalogLen = SQL_NTS; + } + if (!schema.empty()) { + schemaPtr = const_cast(schema.c_str()); + schemaLen = SQL_NTS; + } + if (!table.empty()) { + tablePtr = const_cast(table.c_str()); + tableLen = SQL_NTS; + } + if (!tableType.empty()) { + tableTypePtr = const_cast(tableType.c_str()); + tableTypeLen = SQL_NTS; + } +#endif + + SQLRETURN ret = SQLTables_ptr( + StatementHandle->get(), + catalogPtr, catalogLen, + schemaPtr, schemaLen, + tablePtr, tableLen, + tableTypePtr, tableTypeLen + ); + + if (!SQL_SUCCEEDED(ret)) { + LOG("SQLTables failed with return code: {}", ret); + } else { + LOG("SQLTables succeeded"); + } + + return ret; +} + // Executes the provided query. If the query is parametrized, it prepares the statement and // binds the parameters. Otherwise, it executes the query directly. // 'usePrepare' parameter can be used to disable the prepare step for queries that might already @@ -948,6 +1106,19 @@ SQLRETURN SQLExecute_wrap(const SqlHandlePtr statementHandle, if (!statementHandle || !statementHandle->get()) { LOG("Statement handle is null or empty"); } + + // Ensure statement is scrollable BEFORE executing + if (SQLSetStmtAttr_ptr && hStmt) { + SQLSetStmtAttr_ptr(hStmt, + SQL_ATTR_CURSOR_TYPE, + (SQLPOINTER)SQL_CURSOR_STATIC, + 0); + SQLSetStmtAttr_ptr(hStmt, + SQL_ATTR_CONCURRENCY, + (SQLPOINTER)SQL_CONCUR_READ_ONLY, + 0); + } + SQLWCHAR* queryPtr; #if defined(__APPLE__) || defined(__linux__) std::vector queryBuffer = WStringToSQLWCHAR(query); @@ -1817,6 +1988,20 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p return ret; } +SQLRETURN SQLFetchScroll_wrap(SqlHandlePtr StatementHandle, SQLSMALLINT FetchOrientation, SQLLEN FetchOffset, py::list& /*row_data*/) { + LOG("Fetching with scroll: orientation={}, offset={}", FetchOrientation, FetchOffset); + if (!SQLFetchScroll_ptr) { + LOG("Function pointer not initialized. Loading the driver."); + DriverLoader::getInstance().loadDriver(); // Load the driver + } + + // Perform scroll; do not fetch row data here + return SQLFetchScroll_ptr + ? SQLFetchScroll_ptr(StatementHandle->get(), FetchOrientation, FetchOffset) + : SQL_ERROR; +} + + // For column in the result set, binds a buffer to retrieve column data // TODO: Move to anonymous namespace, since it is not used outside this file SQLRETURN SQLBindColums(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& columnNames, @@ -2307,6 +2492,10 @@ SQLRETURN FetchMany_wrap(SqlHandlePtr StatementHandle, py::list& rows, int fetch return ret; } + // Reset attributes before returning to avoid using stack pointers later + SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROW_ARRAY_SIZE, (SQLPOINTER)1, 0); + SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROWS_FETCHED_PTR, NULL, 0); + return ret; } @@ -2396,6 +2585,10 @@ SQLRETURN FetchAll_wrap(SqlHandlePtr StatementHandle, py::list& rows) { return ret; } } + + // Reset attributes before returning to avoid using stack pointers later + SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROW_ARRAY_SIZE, (SQLPOINTER)1, 0); + SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROWS_FETCHED_PTR, NULL, 0); return ret; } @@ -2553,6 +2746,16 @@ PYBIND11_MODULE(ddbc_bindings, m) { m.def("DDBCSQLFetchAll", &FetchAll_wrap, "Fetch all rows from the result set"); m.def("DDBCSQLFreeHandle", &SQLFreeHandle_wrap, "Free a handle"); m.def("DDBCSQLCheckError", &SQLCheckError_Wrap, "Check for driver errors"); + m.def("DDBCSQLGetAllDiagRecords", &SQLGetAllDiagRecords, + "Get all diagnostic records for a handle", + py::arg("handle")); + m.def("DDBCSQLTables", &SQLTables_wrap, + "Get table information using ODBC SQLTables", + py::arg("StatementHandle"), py::arg("catalog") = std::wstring(), + py::arg("schema") = std::wstring(), py::arg("table") = std::wstring(), + py::arg("tableType") = std::wstring()); + m.def("DDBCSQLFetchScroll", &SQLFetchScroll_wrap, + "Scroll to a specific position in the result set and optionally fetch data"); // Add a version attribute m.attr("__version__") = "1.0.0"; diff --git a/mssql_python/pybind/ddbc_bindings.h b/mssql_python/pybind/ddbc_bindings.h index 22bc524bd..1bb3efb02 100644 --- a/mssql_python/pybind/ddbc_bindings.h +++ b/mssql_python/pybind/ddbc_bindings.h @@ -105,7 +105,18 @@ typedef SQLRETURN (SQL_API* SQLDescribeColFunc)(SQLHSTMT, SQLUSMALLINT, SQLWCHAR typedef SQLRETURN (SQL_API* SQLMoreResultsFunc)(SQLHSTMT); typedef SQLRETURN (SQL_API* SQLColAttributeFunc)(SQLHSTMT, SQLUSMALLINT, SQLUSMALLINT, SQLPOINTER, SQLSMALLINT, SQLSMALLINT*, SQLPOINTER); - +typedef SQLRETURN (*SQLTablesFunc)( + SQLHSTMT StatementHandle, + SQLWCHAR* CatalogName, + SQLSMALLINT NameLength1, + SQLWCHAR* SchemaName, + SQLSMALLINT NameLength2, + SQLWCHAR* TableName, + SQLSMALLINT NameLength3, + SQLWCHAR* TableType, + SQLSMALLINT NameLength4 +); + // Transaction APIs typedef SQLRETURN (SQL_API* SQLEndTranFunc)(SQLSMALLINT, SQLHANDLE, SQLSMALLINT); @@ -148,6 +159,7 @@ extern SQLBindColFunc SQLBindCol_ptr; extern SQLDescribeColFunc SQLDescribeCol_ptr; extern SQLMoreResultsFunc SQLMoreResults_ptr; extern SQLColAttributeFunc SQLColAttribute_ptr; +extern SQLTablesFunc SQLTables_ptr; // Transaction APIs extern SQLEndTranFunc SQLEndTran_ptr; diff --git a/mssql_python/row.py b/mssql_python/row.py index 2c88412de..bbea7fdeb 100644 --- a/mssql_python/row.py +++ b/mssql_python/row.py @@ -9,27 +9,27 @@ class Row: print(row.column_name) # Access by column name """ - def __init__(self, values, cursor_description): + def __init__(self, values, description, column_map=None): """ - Initialize a Row object with values and cursor description. + Initialize a Row object with values and description. Args: - values: List of values for this row - cursor_description: The cursor description containing column metadata + values: List of values for this row. + description: Description of the columns (from cursor.description). + column_map: Optional mapping of column names to indices. """ self._values = values + self._description = description - # TODO: ADO task - Optimize memory usage by sharing column map across rows - # Instead of storing the full cursor_description in each Row object: - # 1. Build the column map once at the cursor level after setting description - # 2. Pass only this map to each Row instance - # 3. Remove cursor_description from Row objects entirely - - # Create mapping of column names to indices - self._column_map = {} - for i, desc in enumerate(cursor_description): - if desc and desc[0]: # Ensure column name exists - self._column_map[desc[0]] = i + # Build column map if not provided + if column_map is None: + self._column_map = {} + for i, desc in enumerate(description): + col_name = desc[0] + self._column_map[col_name] = i + self._column_map[col_name.lower()] = i # Add lowercase for case-insensitivity + else: + self._column_map = column_map def __getitem__(self, index): """Allow accessing by numeric index: row[0]""" diff --git a/tests/test_004_cursor.py b/tests/test_004_cursor.py index 2fc09e73d..78a96b795 100644 --- a/tests/test_004_cursor.py +++ b/tests/test_004_cursor.py @@ -1302,7 +1302,7 @@ def test_row_column_mapping(cursor, db_connection): assert getattr(row, "Complex Name!") == 42, "Complex column name access failed" # Test column map completeness - assert len(row._column_map) == 3, "Column map size incorrect" + assert len(row._column_map) >= 3, "Column map size incorrect" assert "FirstColumn" in row._column_map, "Column map missing CamelCase column" assert "Second_Column" in row._column_map, "Column map missing snake_case column" assert "Complex Name!" in row._column_map, "Column map missing complex name column" @@ -2820,6 +2820,1954 @@ def test_fetchval_performance_common_patterns(cursor, db_connection): except: pass +def test_cursor_commit_basic(cursor, db_connection): + """Test basic cursor commit functionality""" + try: + # Set autocommit to False to test manual commit + original_autocommit = db_connection.autocommit + db_connection.autocommit = False + + # Create test table + drop_table_if_exists(cursor, "#pytest_cursor_commit") + cursor.execute("CREATE TABLE #pytest_cursor_commit (id INTEGER, name VARCHAR(50))") + cursor.commit() # Commit table creation + + # Insert data using cursor + cursor.execute("INSERT INTO #pytest_cursor_commit VALUES (1, 'test1')") + cursor.execute("INSERT INTO #pytest_cursor_commit VALUES (2, 'test2')") + + # Before commit, data should still be visible in same transaction + cursor.execute("SELECT COUNT(*) FROM #pytest_cursor_commit") + count = cursor.fetchval() + assert count == 2, "Data should be visible before commit in same transaction" + + # Commit using cursor + cursor.commit() + + # Verify data is committed + cursor.execute("SELECT COUNT(*) FROM #pytest_cursor_commit") + count = cursor.fetchval() + assert count == 2, "Data should be committed and visible" + + # Verify specific data + cursor.execute("SELECT name FROM #pytest_cursor_commit ORDER BY id") + rows = cursor.fetchall() + assert len(rows) == 2, "Should have 2 rows after commit" + assert rows[0][0] == 'test1', "First row should be 'test1'" + assert rows[1][0] == 'test2', "Second row should be 'test2'" + + except Exception as e: + pytest.fail(f"Cursor commit basic test failed: {e}") + finally: + try: + db_connection.autocommit = original_autocommit + cursor.execute("DROP TABLE #pytest_cursor_commit") + cursor.commit() + except: + pass + +def test_cursor_rollback_basic(cursor, db_connection): + """Test basic cursor rollback functionality""" + try: + # Set autocommit to False to test manual rollback + original_autocommit = db_connection.autocommit + db_connection.autocommit = False + + # Create test table + drop_table_if_exists(cursor, "#pytest_cursor_rollback") + cursor.execute("CREATE TABLE #pytest_cursor_rollback (id INTEGER, name VARCHAR(50))") + cursor.commit() # Commit table creation + + # Insert initial data and commit + cursor.execute("INSERT INTO #pytest_cursor_rollback VALUES (1, 'permanent')") + cursor.commit() + + # Insert more data but don't commit + cursor.execute("INSERT INTO #pytest_cursor_rollback VALUES (2, 'temp1')") + cursor.execute("INSERT INTO #pytest_cursor_rollback VALUES (3, 'temp2')") + + # Before rollback, data should be visible in same transaction + cursor.execute("SELECT COUNT(*) FROM #pytest_cursor_rollback") + count = cursor.fetchval() + assert count == 3, "All data should be visible before rollback in same transaction" + + # Rollback using cursor + cursor.rollback() + + # Verify only committed data remains + cursor.execute("SELECT COUNT(*) FROM #pytest_cursor_rollback") + count = cursor.fetchval() + assert count == 1, "Only committed data should remain after rollback" + + # Verify specific data + cursor.execute("SELECT name FROM #pytest_cursor_rollback") + row = cursor.fetchone() + assert row[0] == 'permanent', "Only the committed row should remain" + + except Exception as e: + pytest.fail(f"Cursor rollback basic test failed: {e}") + finally: + try: + db_connection.autocommit = original_autocommit + cursor.execute("DROP TABLE #pytest_cursor_rollback") + cursor.commit() + except: + pass + +def test_cursor_commit_affects_all_cursors(db_connection): + """Test that cursor commit affects all cursors on the same connection""" + try: + # Set autocommit to False + original_autocommit = db_connection.autocommit + db_connection.autocommit = False + + # Create two cursors + cursor1 = db_connection.cursor() + cursor2 = db_connection.cursor() + + # Create test table using cursor1 + drop_table_if_exists(cursor1, "#pytest_multi_cursor") + cursor1.execute("CREATE TABLE #pytest_multi_cursor (id INTEGER, source VARCHAR(10))") + cursor1.commit() # Commit table creation + + # Insert data using cursor1 + cursor1.execute("INSERT INTO #pytest_multi_cursor VALUES (1, 'cursor1')") + + # Insert data using cursor2 + cursor2.execute("INSERT INTO #pytest_multi_cursor VALUES (2, 'cursor2')") + + # Both cursors should see both inserts before commit + cursor1.execute("SELECT COUNT(*) FROM #pytest_multi_cursor") + count1 = cursor1.fetchval() + cursor2.execute("SELECT COUNT(*) FROM #pytest_multi_cursor") + count2 = cursor2.fetchval() + assert count1 == 2, "Cursor1 should see both inserts" + assert count2 == 2, "Cursor2 should see both inserts" + + # Commit using cursor1 (should affect both cursors) + cursor1.commit() + + # Both cursors should still see the committed data + cursor1.execute("SELECT COUNT(*) FROM #pytest_multi_cursor") + count1 = cursor1.fetchval() + cursor2.execute("SELECT COUNT(*) FROM #pytest_multi_cursor") + count2 = cursor2.fetchval() + assert count1 == 2, "Cursor1 should see committed data" + assert count2 == 2, "Cursor2 should see committed data" + + # Verify data content + cursor1.execute("SELECT source FROM #pytest_multi_cursor ORDER BY id") + rows = cursor1.fetchall() + assert rows[0][0] == 'cursor1', "First row should be from cursor1" + assert rows[1][0] == 'cursor2', "Second row should be from cursor2" + + except Exception as e: + pytest.fail(f"Multi-cursor commit test failed: {e}") + finally: + try: + db_connection.autocommit = original_autocommit + cursor1.execute("DROP TABLE #pytest_multi_cursor") + cursor1.commit() + cursor1.close() + cursor2.close() + except: + pass + +def test_cursor_rollback_affects_all_cursors(db_connection): + """Test that cursor rollback affects all cursors on the same connection""" + try: + # Set autocommit to False + original_autocommit = db_connection.autocommit + db_connection.autocommit = False + + # Create two cursors + cursor1 = db_connection.cursor() + cursor2 = db_connection.cursor() + + # Create test table and insert initial data + drop_table_if_exists(cursor1, "#pytest_multi_rollback") + cursor1.execute("CREATE TABLE #pytest_multi_rollback (id INTEGER, source VARCHAR(10))") + cursor1.execute("INSERT INTO #pytest_multi_rollback VALUES (0, 'baseline')") + cursor1.commit() # Commit initial state + + # Insert data using both cursors + cursor1.execute("INSERT INTO #pytest_multi_rollback VALUES (1, 'cursor1')") + cursor2.execute("INSERT INTO #pytest_multi_rollback VALUES (2, 'cursor2')") + + # Both cursors should see all data before rollback + cursor1.execute("SELECT COUNT(*) FROM #pytest_multi_rollback") + count1 = cursor1.fetchval() + cursor2.execute("SELECT COUNT(*) FROM #pytest_multi_rollback") + count2 = cursor2.fetchval() + assert count1 == 3, "Cursor1 should see all data before rollback" + assert count2 == 3, "Cursor2 should see all data before rollback" + + # Rollback using cursor2 (should affect both cursors) + cursor2.rollback() + + # Both cursors should only see the initial committed data + cursor1.execute("SELECT COUNT(*) FROM #pytest_multi_rollback") + count1 = cursor1.fetchval() + cursor2.execute("SELECT COUNT(*) FROM #pytest_multi_rollback") + count2 = cursor2.fetchval() + assert count1 == 1, "Cursor1 should only see committed data after rollback" + assert count2 == 1, "Cursor2 should only see committed data after rollback" + + # Verify only initial data remains + cursor1.execute("SELECT source FROM #pytest_multi_rollback") + row = cursor1.fetchone() + assert row[0] == 'baseline', "Only the committed row should remain" + + except Exception as e: + pytest.fail(f"Multi-cursor rollback test failed: {e}") + finally: + try: + db_connection.autocommit = original_autocommit + cursor1.execute("DROP TABLE #pytest_multi_rollback") + cursor1.commit() + cursor1.close() + cursor2.close() + except: + pass + +def test_cursor_commit_closed_cursor(db_connection): + """Test cursor commit on closed cursor should raise exception""" + try: + cursor = db_connection.cursor() + cursor.close() + + with pytest.raises(Exception) as exc_info: + cursor.commit() + + assert "closed" in str(exc_info.value).lower(), "commit on closed cursor should raise exception mentioning cursor is closed" + + except Exception as e: + if "closed" not in str(e).lower(): + pytest.fail(f"Cursor commit closed cursor test failed: {e}") + +def test_cursor_rollback_closed_cursor(db_connection): + """Test cursor rollback on closed cursor should raise exception""" + try: + cursor = db_connection.cursor() + cursor.close() + + with pytest.raises(Exception) as exc_info: + cursor.rollback() + + assert "closed" in str(exc_info.value).lower(), "rollback on closed cursor should raise exception mentioning cursor is closed" + + except Exception as e: + if "closed" not in str(e).lower(): + pytest.fail(f"Cursor rollback closed cursor test failed: {e}") + +def test_cursor_commit_equivalent_to_connection_commit(cursor, db_connection): + """Test that cursor.commit() is equivalent to connection.commit()""" + try: + # Set autocommit to False + original_autocommit = db_connection.autocommit + db_connection.autocommit = False + + # Create test table + drop_table_if_exists(cursor, "#pytest_commit_equiv") + cursor.execute("CREATE TABLE #pytest_commit_equiv (id INTEGER, method VARCHAR(20))") + cursor.commit() + + # Test 1: Use cursor.commit() + cursor.execute("INSERT INTO #pytest_commit_equiv VALUES (1, 'cursor_commit')") + cursor.commit() + + # Verify the chained operation worked + result = cursor.execute("SELECT method FROM #pytest_commit_equiv WHERE id = 1").fetchval() + assert result == 'cursor_commit', "Method chaining with commit should work" + + # Test 2: Use connection.commit() + cursor.execute("INSERT INTO #pytest_commit_equiv VALUES (2, 'conn_commit')") + db_connection.commit() + + cursor.execute("SELECT method FROM #pytest_commit_equiv WHERE id = 2") + result = cursor.fetchone() + assert result[0] == 'conn_commit', "Should return 'conn_commit'" + + # Test 3: Mix both methods + cursor.execute("INSERT INTO #pytest_commit_equiv VALUES (3, 'mixed1')") + cursor.commit() # Use cursor + cursor.execute("INSERT INTO #pytest_commit_equiv VALUES (4, 'mixed2')") + db_connection.commit() # Use connection + + cursor.execute("SELECT method FROM #pytest_commit_equiv ORDER BY id") + rows = cursor.fetchall() + assert len(rows) == 4, "Should have 4 rows after mixed commits" + assert rows[0][0] == 'cursor_commit', "First row should be 'cursor_commit'" + assert rows[1][0] == 'conn_commit', "Second row should be 'conn_commit'" + + except Exception as e: + pytest.fail(f"Cursor commit equivalence test failed: {e}") + finally: + try: + db_connection.autocommit = original_autocommit + cursor.execute("DROP TABLE #pytest_commit_equiv") + cursor.commit() + except: + pass + +def test_cursor_transaction_boundary_behavior(cursor, db_connection): + """Test cursor commit/rollback behavior at transaction boundaries""" + try: + # Set autocommit to False + original_autocommit = db_connection.autocommit + db_connection.autocommit = False + + # Create test table + drop_table_if_exists(cursor, "#pytest_transaction") + cursor.execute("CREATE TABLE #pytest_transaction (id INTEGER, step VARCHAR(20))") + cursor.commit() + + # Transaction 1: Insert and commit + cursor.execute("INSERT INTO #pytest_transaction VALUES (1, 'step1')") + cursor.commit() + + # Transaction 2: Insert, rollback, then insert different data and commit + cursor.execute("INSERT INTO #pytest_transaction VALUES (2, 'temp')") + cursor.rollback() # This should rollback the temp insert + + cursor.execute("INSERT INTO #pytest_transaction VALUES (2, 'step2')") + cursor.commit() + + # Verify final state + cursor.execute("SELECT step FROM #pytest_transaction ORDER BY id") + rows = cursor.fetchall() + assert len(rows) == 2, "Should have 2 rows" + assert rows[0][0] == 'step1', "First row should be step1" + assert rows[1][0] == 'step2', "Second row should be step2 (not temp)" + + # Transaction 3: Multiple operations with rollback + cursor.execute("INSERT INTO #pytest_transaction VALUES (3, 'temp1')") + cursor.execute("INSERT INTO #pytest_transaction VALUES (4, 'temp2')") + cursor.execute("DELETE FROM #pytest_transaction WHERE id = 1") + cursor.rollback() # Rollback all operations in transaction 3 + + # Verify rollback worked + cursor.execute("SELECT COUNT(*) FROM #pytest_transaction") + count = cursor.fetchval() + assert count == 2, "Rollback should restore previous state" + + cursor.execute("SELECT id FROM #pytest_transaction ORDER BY id") + rows = cursor.fetchall() + assert rows[0][0] == 1, "Row 1 should still exist after rollback" + assert rows[1][0] == 2, "Row 2 should still exist after rollback" + + except Exception as e: + pytest.fail(f"Transaction boundary behavior test failed: {e}") + finally: + try: + db_connection.autocommit = original_autocommit + cursor.execute("DROP TABLE #pytest_transaction") + cursor.commit() + except: + pass + +def test_cursor_commit_with_method_chaining(cursor, db_connection): + """Test cursor commit in method chaining scenarios""" + try: + # Set autocommit to False + original_autocommit = db_connection.autocommit + db_connection.autocommit = False + + # Create test table + drop_table_if_exists(cursor, "#pytest_chaining") + cursor.execute("CREATE TABLE #pytest_chaining (id INTEGER, value VARCHAR(20))") + cursor.commit() + + # Test method chaining with execute and commit + cursor.execute("INSERT INTO #pytest_chaining VALUES (1, 'chained')") + cursor.commit() + + # Verify the chained operation worked + result = cursor.execute("SELECT value FROM #pytest_chaining WHERE id = 1").fetchval() + assert result == 'chained', "Method chaining with commit should work" + + # Verify rollback worked + count = cursor.execute("SELECT COUNT(*) FROM #pytest_chaining").fetchval() + assert count == 1, "Rollback after chained operations should work" + + except Exception as e: + pytest.fail(f"Cursor commit method chaining test failed: {e}") + finally: + try: + db_connection.autocommit = original_autocommit + cursor.execute("DROP TABLE #pytest_chaining") + cursor.commit() + except: + pass + +def test_cursor_commit_error_scenarios(cursor, db_connection): + """Test cursor commit error scenarios and recovery""" + try: + # Set autocommit to False + original_autocommit = db_connection.autocommit + db_connection.autocommit = False + + # Create test table + drop_table_if_exists(cursor, "#pytest_commit_errors") + cursor.execute("CREATE TABLE #pytest_commit_errors (id INTEGER PRIMARY KEY, value VARCHAR(20))") + cursor.commit() + + # Insert valid data + cursor.execute("INSERT INTO #pytest_commit_errors VALUES (1, 'valid')") + cursor.commit() + + # Try to insert duplicate key (should fail) + try: + cursor.execute("INSERT INTO #pytest_commit_errors VALUES (1, 'duplicate')") + cursor.commit() # This might succeed depending on when the constraint is checked + pytest.fail("Expected constraint violation") + except Exception: + # Expected - constraint violation + cursor.rollback() # Clean up the failed transaction + + # Verify we can still use the cursor after error and rollback + cursor.execute("INSERT INTO #pytest_commit_errors VALUES (2, 'after_error')") + cursor.commit() + + cursor.execute("SELECT COUNT(*) FROM #pytest_commit_errors") + count = cursor.fetchval() + assert count == 2, "Should have 2 rows after error recovery" + + # Verify data integrity + cursor.execute("SELECT value FROM #pytest_commit_errors ORDER BY id") + rows = cursor.fetchall() + assert rows[0][0] == 'valid', "First row should be unchanged" + assert rows[1][0] == 'after_error', "Second row should be the recovery insert" + + except Exception as e: + pytest.fail(f"Cursor commit error scenarios test failed: {e}") + finally: + try: + db_connection.autocommit = original_autocommit + cursor.execute("DROP TABLE #pytest_commit_errors") + cursor.commit() + except: + pass + +def test_cursor_commit_performance_patterns(cursor, db_connection): + """Test cursor commit with performance-related patterns""" + try: + # Set autocommit to False + original_autocommit = db_connection.autocommit + db_connection.autocommit = False + + # Create test table + drop_table_if_exists(cursor, "#pytest_commit_perf") + cursor.execute("CREATE TABLE #pytest_commit_perf (id INTEGER, batch_num INTEGER)") + cursor.commit() + + # Test batch insert with periodic commits + batch_size = 5 + total_records = 15 + + for i in range(total_records): + batch_num = i // batch_size + cursor.execute("INSERT INTO #pytest_commit_perf VALUES (?, ?)", i, batch_num) + + # Commit every batch_size records + if (i + 1) % batch_size == 0: + cursor.commit() + + # Commit any remaining records + cursor.commit() + + # Verify all records were inserted + cursor.execute("SELECT COUNT(*) FROM #pytest_commit_perf") + count = cursor.fetchval() + assert count == total_records, f"Should have {total_records} records" + + # Verify batch distribution + cursor.execute("SELECT batch_num, COUNT(*) FROM #pytest_commit_perf GROUP BY batch_num ORDER BY batch_num") + batches = cursor.fetchall() + assert len(batches) == 3, "Should have 3 batches" + assert batches[0][1] == 5, "First batch should have 5 records" + assert batches[1][1] == 5, "Second batch should have 5 records" + assert batches[2][1] == 5, "Third batch should have 5 records" + + except Exception as e: + pytest.fail(f"Cursor commit performance patterns test failed: {e}") + finally: + try: + db_connection.autocommit = original_autocommit + cursor.execute("DROP TABLE #pytest_commit_perf") + cursor.commit() + except: + pass + +def test_cursor_rollback_error_scenarios(cursor, db_connection): + """Test cursor rollback error scenarios and recovery""" + try: + # Set autocommit to False + original_autocommit = db_connection.autocommit + db_connection.autocommit = False + + # Create test table + drop_table_if_exists(cursor, "#pytest_rollback_errors") + cursor.execute("CREATE TABLE #pytest_rollback_errors (id INTEGER PRIMARY KEY, value VARCHAR(20))") + cursor.commit() + + # Insert valid data and commit + cursor.execute("INSERT INTO #pytest_rollback_errors VALUES (1, 'committed')") + cursor.commit() + + # Start a transaction with multiple operations + cursor.execute("INSERT INTO #pytest_rollback_errors VALUES (2, 'temp1')") + cursor.execute("INSERT INTO #pytest_rollback_errors VALUES (3, 'temp2')") + cursor.execute("UPDATE #pytest_rollback_errors SET value = 'modified' WHERE id = 1") + + # Verify uncommitted changes are visible within transaction + cursor.execute("SELECT COUNT(*) FROM #pytest_rollback_errors") + count = cursor.fetchval() + assert count == 3, "Should see all uncommitted changes within transaction" + + cursor.execute("SELECT value FROM #pytest_rollback_errors WHERE id = 1") + modified_value = cursor.fetchval() + assert modified_value == 'modified', "Should see uncommitted modification" + + # Rollback the transaction + cursor.rollback() + + # Verify rollback restored original state + cursor.execute("SELECT COUNT(*) FROM #pytest_rollback_errors") + count = cursor.fetchval() + assert count == 1, "Should only have committed data after rollback" + + cursor.execute("SELECT value FROM #pytest_rollback_errors WHERE id = 1") + original_value = cursor.fetchval() + assert original_value == 'committed', "Original value should be restored after rollback" + + # Verify cursor is still usable after rollback + cursor.execute("INSERT INTO #pytest_rollback_errors VALUES (4, 'after_rollback')") + cursor.commit() + + cursor.execute("SELECT COUNT(*) FROM #pytest_rollback_errors") + count = cursor.fetchval() + assert count == 2, "Should have 2 rows after recovery" + + # Verify data integrity + cursor.execute("SELECT value FROM #pytest_rollback_errors ORDER BY id") + rows = cursor.fetchall() + assert rows[0][0] == 'committed', "First row should be unchanged" + assert rows[1][0] == 'after_rollback', "Second row should be the recovery insert" + + except Exception as e: + pytest.fail(f"Cursor rollback error scenarios test failed: {e}") + finally: + try: + db_connection.autocommit = original_autocommit + cursor.execute("DROP TABLE #pytest_rollback_errors") + cursor.commit() + except: + pass + +def test_cursor_rollback_with_method_chaining(cursor, db_connection): + """Test cursor rollback in method chaining scenarios""" + try: + # Set autocommit to False + original_autocommit = db_connection.autocommit + db_connection.autocommit = False + + # Create test table + drop_table_if_exists(cursor, "#pytest_rollback_chaining") + cursor.execute("CREATE TABLE #pytest_rollback_chaining (id INTEGER, value VARCHAR(20))") + cursor.commit() + + # Insert initial committed data + cursor.execute("INSERT INTO #pytest_rollback_chaining VALUES (1, 'permanent')") + cursor.commit() + + # Test method chaining with execute and rollback + cursor.execute("INSERT INTO #pytest_rollback_chaining VALUES (2, 'temporary')") + + # Verify temporary data is visible before rollback + result = cursor.execute("SELECT COUNT(*) FROM #pytest_rollback_chaining").fetchval() + assert result == 2, "Should see temporary data before rollback" + + # Rollback the temporary insert + cursor.rollback() + + # Verify rollback worked with method chaining + count = cursor.execute("SELECT COUNT(*) FROM #pytest_rollback_chaining").fetchval() + assert count == 1, "Should only have permanent data after rollback" + + # Test chaining after rollback + value = cursor.execute("SELECT value FROM #pytest_rollback_chaining WHERE id = 1").fetchval() + assert value == 'permanent', "Method chaining should work after rollback" + + except Exception as e: + pytest.fail(f"Cursor rollback method chaining test failed: {e}") + finally: + try: + db_connection.autocommit = original_autocommit + cursor.execute("DROP TABLE #pytest_rollback_chaining") + cursor.commit() + except: + pass + +def test_cursor_rollback_savepoints_simulation(cursor, db_connection): + """Test cursor rollback with simulated savepoint behavior""" + try: + # Set autocommit to False + original_autocommit = db_connection.autocommit + db_connection.autocommit = False + + # Create test table + drop_table_if_exists(cursor, "#pytest_rollback_savepoints") + cursor.execute("CREATE TABLE #pytest_rollback_savepoints (id INTEGER, stage VARCHAR(20))") + cursor.commit() + + # Stage 1: Insert and commit (simulated savepoint) + cursor.execute("INSERT INTO #pytest_rollback_savepoints VALUES (1, 'stage1')") + cursor.commit() + + # Stage 2: Insert more data but don't commit + cursor.execute("INSERT INTO #pytest_rollback_savepoints VALUES (2, 'stage2')") + cursor.execute("INSERT INTO #pytest_rollback_savepoints VALUES (3, 'stage2')") + + # Verify stage 2 data is visible + cursor.execute("SELECT COUNT(*) FROM #pytest_rollback_savepoints WHERE stage = 'stage2'") + stage2_count = cursor.fetchval() + assert stage2_count == 2, "Should see stage 2 data before rollback" + + # Rollback stage 2 (back to stage 1) + cursor.rollback() + + # Verify only stage 1 data remains + cursor.execute("SELECT COUNT(*) FROM #pytest_rollback_savepoints") + total_count = cursor.fetchval() + assert total_count == 1, "Should only have stage 1 data after rollback" + + cursor.execute("SELECT stage FROM #pytest_rollback_savepoints") + remaining_stage = cursor.fetchval() + assert remaining_stage == 'stage1', "Should only have stage 1 data" + + # Stage 3: Try different operations and rollback + cursor.execute("INSERT INTO #pytest_rollback_savepoints VALUES (4, 'stage3')") + cursor.execute("UPDATE #pytest_rollback_savepoints SET stage = 'modified' WHERE id = 1") + cursor.execute("INSERT INTO #pytest_rollback_savepoints VALUES (5, 'stage3')") + + # Verify stage 3 changes + cursor.execute("SELECT COUNT(*) FROM #pytest_rollback_savepoints") + stage3_count = cursor.fetchval() + assert stage3_count == 3, "Should see all stage 3 changes" + + # Rollback stage 3 + cursor.rollback() + + # Verify back to stage 1 + cursor.execute("SELECT COUNT(*) FROM #pytest_rollback_savepoints") + final_count = cursor.fetchval() + assert final_count == 1, "Should be back to stage 1 after second rollback" + + cursor.execute("SELECT stage FROM #pytest_rollback_savepoints WHERE id = 1") + final_stage = cursor.fetchval() + assert final_stage == 'stage1', "Stage 1 data should be unmodified" + + except Exception as e: + pytest.fail(f"Cursor rollback savepoints simulation test failed: {e}") + finally: + try: + db_connection.autocommit = original_autocommit + cursor.execute("DROP TABLE #pytest_rollback_savepoints") + cursor.commit() + except: + pass + +def test_cursor_rollback_performance_patterns(cursor, db_connection): + """Test cursor rollback with performance-related patterns""" + try: + # Set autocommit to False + original_autocommit = db_connection.autocommit + db_connection.autocommit = False + + # Create test table + drop_table_if_exists(cursor, "#pytest_rollback_perf") + cursor.execute("CREATE TABLE #pytest_rollback_perf (id INTEGER, batch_num INTEGER, status VARCHAR(10))") + cursor.commit() + + # Simulate batch processing with selective rollback + batch_size = 5 + total_batches = 3 + + for batch_num in range(total_batches): + try: + # Process a batch + for i in range(batch_size): + record_id = batch_num * batch_size + i + 1 + + # Simulate some records failing based on business logic + if batch_num == 1 and i >= 3: # Simulate failure in batch 1 + cursor.execute("INSERT INTO #pytest_rollback_perf VALUES (?, ?, ?)", + record_id, batch_num, 'error') + # Simulate error condition + raise Exception(f"Simulated error in batch {batch_num}") + else: + cursor.execute("INSERT INTO #pytest_rollback_perf VALUES (?, ?, ?)", + record_id, batch_num, 'success') + + # If batch completed successfully, commit + cursor.commit() + print(f"Batch {batch_num} committed successfully") + + except Exception as e: + # If batch failed, rollback + cursor.rollback() + print(f"Batch {batch_num} rolled back due to: {e}") + + # Verify only successful batches were committed + cursor.execute("SELECT COUNT(*) FROM #pytest_rollback_perf") + total_count = cursor.fetchval() + assert total_count == 10, "Should have 10 records (2 successful batches of 5 each)" + + # Verify batch distribution + cursor.execute("SELECT batch_num, COUNT(*) FROM #pytest_rollback_perf GROUP BY batch_num ORDER BY batch_num") + batches = cursor.fetchall() + assert len(batches) == 2, "Should have 2 successful batches" + assert batches[0][0] == 0 and batches[0][1] == 5, "Batch 0 should have 5 records" + assert batches[1][0] == 2 and batches[1][1] == 5, "Batch 2 should have 5 records" + + # Verify no error records exist (they were rolled back) + cursor.execute("SELECT COUNT(*) FROM #pytest_rollback_perf WHERE status = 'error'") + error_count = cursor.fetchval() + assert error_count == 0, "No error records should exist after rollbacks" + + except Exception as e: + pytest.fail(f"Cursor rollback performance patterns test failed: {e}") + finally: + try: + db_connection.autocommit = original_autocommit + cursor.execute("DROP TABLE #pytest_rollback_perf") + cursor.commit() + except: + pass + +def test_cursor_rollback_equivalent_to_connection_rollback(cursor, db_connection): + """Test that cursor.rollback() is equivalent to connection.rollback()""" + try: + # Set autocommit to False + original_autocommit = db_connection.autocommit + db_connection.autocommit = False + + # Create test table + drop_table_if_exists(cursor, "#pytest_rollback_equiv") + cursor.execute("CREATE TABLE #pytest_rollback_equiv (id INTEGER, method VARCHAR(20))") + cursor.commit() + + # Test 1: Use cursor.rollback() + cursor.execute("INSERT INTO #pytest_rollback_equiv VALUES (1, 'cursor_rollback')") + cursor.execute("SELECT COUNT(*) FROM #pytest_rollback_equiv") + count = cursor.fetchval() + assert count == 1, "Data should be visible before rollback" + + cursor.rollback() # Use cursor.rollback() + + cursor.execute("SELECT COUNT(*) FROM #pytest_rollback_equiv") + count = cursor.fetchval() + assert count == 0, "Data should be rolled back via cursor.rollback()" + + # Test 2: Use connection.rollback() + cursor.execute("INSERT INTO #pytest_rollback_equiv VALUES (2, 'conn_rollback')") + cursor.execute("SELECT COUNT(*) FROM #pytest_rollback_equiv") + count = cursor.fetchval() + assert count == 1, "Data should be visible before rollback" + + db_connection.rollback() # Use connection.rollback() + + cursor.execute("SELECT COUNT(*) FROM #pytest_rollback_equiv") + count = cursor.fetchval() + assert count == 0, "Data should be rolled back via connection.rollback()" + + # Test 3: Mix both methods + cursor.execute("INSERT INTO #pytest_rollback_equiv VALUES (3, 'mixed1')") + cursor.rollback() # Use cursor + + cursor.execute("INSERT INTO #pytest_rollback_equiv VALUES (4, 'mixed2')") + db_connection.rollback() # Use connection + + cursor.execute("SELECT COUNT(*) FROM #pytest_rollback_equiv") + count = cursor.fetchval() + assert count == 0, "Both rollback methods should work equivalently" + + # Test 4: Verify both commit and rollback work together + cursor.execute("INSERT INTO #pytest_rollback_equiv VALUES (5, 'final_test')") + cursor.commit() # Commit this one + + cursor.execute("INSERT INTO #pytest_rollback_equiv VALUES (6, 'temp')") + cursor.rollback() # Rollback this one + + cursor.execute("SELECT COUNT(*) FROM #pytest_rollback_equiv") + count = cursor.fetchval() + assert count == 1, "Should have only the committed record" + + cursor.execute("SELECT method FROM #pytest_rollback_equiv") + method = cursor.fetchval() + assert method == 'final_test', "Should have the committed record" + + except Exception as e: + pytest.fail(f"Cursor rollback equivalence test failed: {e}") + finally: + try: + db_connection.autocommit = original_autocommit + cursor.execute("DROP TABLE #pytest_rollback_equiv") + cursor.commit() + except: + pass + +def test_cursor_rollback_nested_transactions_simulation(cursor, db_connection): + """Test cursor rollback with simulated nested transaction behavior""" + try: + # Set autocommit to False + original_autocommit = db_connection.autocommit + db_connection.autocommit = False + + # Create test table + drop_table_if_exists(cursor, "#pytest_rollback_nested") + cursor.execute("CREATE TABLE #pytest_rollback_nested (id INTEGER, level VARCHAR(20), operation VARCHAR(20))") + cursor.commit() + + # Outer transaction level + cursor.execute("INSERT INTO #pytest_rollback_nested VALUES (1, 'outer', 'insert')") + cursor.execute("INSERT INTO #pytest_rollback_nested VALUES (2, 'outer', 'insert')") + + # Verify outer level data + cursor.execute("SELECT COUNT(*) FROM #pytest_rollback_nested WHERE level = 'outer'") + outer_count = cursor.fetchval() + assert outer_count == 2, "Should have 2 outer level records" + + # Simulate inner transaction + cursor.execute("INSERT INTO #pytest_rollback_nested VALUES (3, 'inner', 'insert')") + cursor.execute("UPDATE #pytest_rollback_nested SET operation = 'updated' WHERE level = 'outer' AND id = 1") + cursor.execute("INSERT INTO #pytest_rollback_nested VALUES (4, 'inner', 'insert')") + + # Verify inner changes are visible + cursor.execute("SELECT COUNT(*) FROM #pytest_rollback_nested") + total_count = cursor.fetchval() + assert total_count == 4, "Should see all records including inner changes" + + cursor.execute("SELECT operation FROM #pytest_rollback_nested WHERE id = 1") + updated_op = cursor.fetchval() + assert updated_op == 'updated', "Should see updated operation" + + # Rollback everything (simulating inner transaction failure affecting outer) + cursor.rollback() + + # Verify complete rollback + cursor.execute("SELECT COUNT(*) FROM #pytest_rollback_nested") + final_count = cursor.fetchval() + assert final_count == 0, "All changes should be rolled back" + + # Test successful nested-like pattern + # Outer level + cursor.execute("INSERT INTO #pytest_rollback_nested VALUES (1, 'outer', 'insert')") + cursor.commit() # Commit outer level + + # Inner level + cursor.execute("INSERT INTO #pytest_rollback_nested VALUES (2, 'inner', 'insert')") + cursor.execute("INSERT INTO #pytest_rollback_nested VALUES (3, 'inner', 'insert')") + cursor.rollback() # Rollback only inner level + + # Verify only outer level remains + cursor.execute("SELECT COUNT(*) FROM #pytest_rollback_nested") + remaining_count = cursor.fetchval() + assert remaining_count == 1, "Should only have committed outer level data" + + cursor.execute("SELECT level FROM #pytest_rollback_nested") + remaining_level = cursor.fetchval() + assert remaining_level == 'outer', "Should only have outer level record" + + except Exception as e: + pytest.fail(f"Cursor rollback nested transactions test failed: {e}") + finally: + try: + db_connection.autocommit = original_autocommit + cursor.execute("DROP TABLE #pytest_rollback_nested") + cursor.commit() + except: + pass + +def test_cursor_rollback_data_consistency(cursor, db_connection): + """Test cursor rollback maintains data consistency""" + try: + # Set autocommit to False + original_autocommit = db_connection.autocommit + db_connection.autocommit = False + + # Create related tables to test referential integrity + drop_table_if_exists(cursor, "#pytest_rollback_orders") + drop_table_if_exists(cursor, "#pytest_rollback_customers") + + cursor.execute(""" + CREATE TABLE #pytest_rollback_customers ( + id INTEGER PRIMARY KEY, + name VARCHAR(50) + ) + """) + + cursor.execute(""" + CREATE TABLE #pytest_rollback_orders ( + id INTEGER PRIMARY KEY, + customer_id INTEGER, + amount DECIMAL(10,2), + FOREIGN KEY (customer_id) REFERENCES #pytest_rollback_customers(id) + ) + """) + cursor.commit() + + # Insert initial data + cursor.execute("INSERT INTO #pytest_rollback_customers VALUES (1, 'John Doe')") + cursor.execute("INSERT INTO #pytest_rollback_customers VALUES (2, 'Jane Smith')") + cursor.commit() + + # Start transaction with multiple related operations + cursor.execute("INSERT INTO #pytest_rollback_customers VALUES (3, 'Bob Wilson')") + cursor.execute("INSERT INTO #pytest_rollback_orders VALUES (1, 1, 100.00)") + cursor.execute("INSERT INTO #pytest_rollback_orders VALUES (2, 2, 200.00)") + cursor.execute("INSERT INTO #pytest_rollback_orders VALUES (3, 3, 300.00)") + cursor.execute("UPDATE #pytest_rollback_customers SET name = 'John Updated' WHERE id = 1") + + # Verify uncommitted changes + cursor.execute("SELECT COUNT(*) FROM #pytest_rollback_customers") + customer_count = cursor.fetchval() + assert customer_count == 3, "Should have 3 customers before rollback" + + cursor.execute("SELECT COUNT(*) FROM #pytest_rollback_orders") + order_count = cursor.fetchval() + assert order_count == 3, "Should have 3 orders before rollback" + + cursor.execute("SELECT name FROM #pytest_rollback_customers WHERE id = 1") + updated_name = cursor.fetchval() + assert updated_name == 'John Updated', "Should see updated name" + + # Rollback all changes + cursor.rollback() + + # Verify data consistency after rollback + cursor.execute("SELECT COUNT(*) FROM #pytest_rollback_customers") + final_customer_count = cursor.fetchval() + assert final_customer_count == 2, "Should have original 2 customers after rollback" + + cursor.execute("SELECT COUNT(*) FROM #pytest_rollback_orders") + final_order_count = cursor.fetchval() + assert final_order_count == 0, "Should have no orders after rollback" + + cursor.execute("SELECT name FROM #pytest_rollback_customers WHERE id = 1") + original_name = cursor.fetchval() + assert original_name == 'John Doe', "Should have original name after rollback" + + # Verify referential integrity is maintained + cursor.execute("SELECT name FROM #pytest_rollback_customers ORDER BY id") + names = cursor.fetchall() + assert len(names) == 2, "Should have exactly 2 customers" + assert names[0][0] == 'John Doe', "First customer should be John Doe" + assert names[1][0] == 'Jane Smith', "Second customer should be Jane Smith" + + except Exception as e: + pytest.fail(f"Cursor rollback data consistency test failed: {e}") + finally: + try: + db_connection.autocommit = original_autocommit + cursor.execute("DROP TABLE #pytest_rollback_orders") + cursor.execute("DROP TABLE #pytest_rollback_customers") + cursor.commit() + except: + pass + +def test_cursor_rollback_large_transaction(cursor, db_connection): + """Test cursor rollback with large transaction""" + try: + # Set autocommit to False + original_autocommit = db_connection.autocommit + db_connection.autocommit = False + + # Create test table + drop_table_if_exists(cursor, "#pytest_rollback_large") + cursor.execute("CREATE TABLE #pytest_rollback_large (id INTEGER, data VARCHAR(100))") + cursor.commit() + + # Insert committed baseline data + cursor.execute("INSERT INTO #pytest_rollback_large VALUES (0, 'baseline')") + cursor.commit() + + # Start large transaction + large_transaction_size = 100 + + for i in range(1, large_transaction_size + 1): + cursor.execute("INSERT INTO #pytest_rollback_large VALUES (?, ?)", + i, f'large_transaction_data_{i}') + + # Add some updates to make transaction more complex + if i % 10 == 0: + cursor.execute("UPDATE #pytest_rollback_large SET data = ? WHERE id = ?", + f'updated_data_{i}', i) + + # Verify large transaction data is visible + cursor.execute("SELECT COUNT(*) FROM #pytest_rollback_large") + total_count = cursor.fetchval() + assert total_count == large_transaction_size + 1, f"Should have {large_transaction_size + 1} records before rollback" + + # Verify some updated data + cursor.execute("SELECT data FROM #pytest_rollback_large WHERE id = 10") + updated_data = cursor.fetchval() + assert updated_data == 'updated_data_10', "Should see updated data" + + # Rollback the large transaction + cursor.rollback() + + # Verify rollback worked + cursor.execute("SELECT COUNT(*) FROM #pytest_rollback_large") + final_count = cursor.fetchval() + assert final_count == 1, "Should only have baseline data after rollback" + + cursor.execute("SELECT data FROM #pytest_rollback_large WHERE id = 0") + baseline_data = cursor.fetchval() + assert baseline_data == 'baseline', "Baseline data should be unchanged" + + # Verify no large transaction data remains + cursor.execute("SELECT COUNT(*) FROM #pytest_rollback_large WHERE id > 0") + large_data_count = cursor.fetchval() + assert large_data_count == 0, "No large transaction data should remain" + + except Exception as e: + pytest.fail(f"Cursor rollback large transaction test failed: {e}") + finally: + try: + db_connection.autocommit = original_autocommit + cursor.execute("DROP TABLE #pytest_rollback_large") + cursor.commit() + except: + pass + +# Helper for these scroll tests to avoid name collisions with other helpers +def _drop_if_exists_scroll(cursor, name): + try: + cursor.execute(f"DROP TABLE {name}") + cursor.commit() + except Exception: + pass + + +def test_scroll_relative_basic(cursor, db_connection): + """Relative scroll should advance by the given offset and update rownumber.""" + try: + _drop_if_exists_scroll(cursor, "#t_scroll_rel") + cursor.execute("CREATE TABLE #t_scroll_rel (id INTEGER)") + cursor.executemany("INSERT INTO #t_scroll_rel VALUES (?)", [(i,) for i in range(1, 11)]) + db_connection.commit() + + cursor.execute("SELECT id FROM #t_scroll_rel ORDER BY id") + # from fresh result set, skip 3 rows -> last-returned index becomes 2 (0-based) + cursor.scroll(3) + assert cursor.rownumber == 2, "After scroll(3) last-returned index should be 2" + + # Fetch current row to verify position: next fetch should return id=4 + row = cursor.fetchone() + assert row[0] == 4, "After scroll(3) the next fetch should return id=4" + # after fetch, last-returned index advances to 3 + assert cursor.rownumber == 3, "After fetchone(), last-returned index should be 3" + + finally: + _drop_if_exists_scroll(cursor, "#t_scroll_rel") + + +def test_scroll_absolute_basic(cursor, db_connection): + """Absolute scroll should position so the next fetch returns the requested index.""" + try: + _drop_if_exists_scroll(cursor, "#t_scroll_abs") + cursor.execute("CREATE TABLE #t_scroll_abs (id INTEGER)") + cursor.executemany("INSERT INTO #t_scroll_abs VALUES (?)", [(i,) for i in range(1, 8)]) + db_connection.commit() + + cursor.execute("SELECT id FROM #t_scroll_abs ORDER BY id") + + # absolute position 0 -> set last-returned index to 0 (position BEFORE fetch) + cursor.scroll(0, "absolute") + assert cursor.rownumber == 0, "After absolute(0) rownumber should be 0 (positioned at index 0)" + row = cursor.fetchone() + assert row[0] == 1, "At absolute position 0, fetch should return first row" + # after fetch, last-returned index remains 0 (implementation sets to last returned row) + assert cursor.rownumber == 0, "After fetch at absolute(0), last-returned index should be 0" + + # absolute position 3 -> next fetch should return id=4 + cursor.scroll(3, "absolute") + assert cursor.rownumber == 3, "After absolute(3) rownumber should be 3" + row = cursor.fetchone() + assert row[0] == 4, "At absolute position 3, should fetch row with id=4" + + finally: + _drop_if_exists_scroll(cursor, "#t_scroll_abs") + + +def test_scroll_backward_not_supported(cursor, db_connection): + """Backward scrolling must raise NotSupportedError for negative relative; absolute to same or forward allowed.""" + from mssql_python.exceptions import NotSupportedError + try: + _drop_if_exists_scroll(cursor, "#t_scroll_back") + cursor.execute("CREATE TABLE #t_scroll_back (id INTEGER)") + cursor.executemany("INSERT INTO #t_scroll_back VALUES (?)", [(1,), (2,), (3,)]) + db_connection.commit() + + cursor.execute("SELECT id FROM #t_scroll_back ORDER BY id") + + # move forward 1 (relative) + cursor.scroll(1) + # Implementation semantics: scroll(1) consumes 1 row -> last-returned index becomes 0 + assert cursor.rownumber == 0, "After scroll(1) from start last-returned index should be 0" + + # negative relative should raise NotSupportedError and not change position + last = cursor.rownumber + with pytest.raises(NotSupportedError): + cursor.scroll(-1) + assert cursor.rownumber == last + + # absolute to a lower position: if target < current_last_index, NotSupportedError expected. + # But absolute to the same position is allowed; ensure behavior is consistent with implementation. + # Here target equals current, so no error and position remains same. + cursor.scroll(last, "absolute") + assert cursor.rownumber == last + + finally: + _drop_if_exists_scroll(cursor, "#t_scroll_back") + + +def test_scroll_on_empty_result_set_raises(cursor, db_connection): + """Empty result set: relative scroll should raise IndexError; absolute sets position but fetch returns None.""" + try: + _drop_if_exists_scroll(cursor, "#t_scroll_empty") + cursor.execute("CREATE TABLE #t_scroll_empty (id INTEGER)") + db_connection.commit() + + cursor.execute("SELECT id FROM #t_scroll_empty") + assert cursor.rownumber == -1 + + # relative scroll on empty should raise IndexError + with pytest.raises(IndexError): + cursor.scroll(1) + + # absolute to 0 on empty: implementation sets the position (rownumber) but there is no row to fetch + cursor.scroll(0, "absolute") + assert cursor.rownumber == 0, "Absolute scroll on empty result sets sets rownumber to target" + assert cursor.fetchone() is None, "No row should be returned after absolute positioning into empty set" + + finally: + _drop_if_exists_scroll(cursor, "#t_scroll_empty") + + +def test_scroll_mixed_fetches_consume_correctly(cursor, db_connection): + """Mix fetchone/fetchmany/fetchall with scroll and ensure correct results (match implementation).""" + try: + _drop_if_exists_scroll(cursor, "#t_scroll_mix") + cursor.execute("CREATE TABLE #t_scroll_mix (id INTEGER)") + cursor.executemany("INSERT INTO #t_scroll_mix VALUES (?)", [(i,) for i in range(1, 11)]) + db_connection.commit() + + cursor.execute("SELECT id FROM #t_scroll_mix ORDER BY id") + + # fetchone, then scroll + row1 = cursor.fetchone() + assert row1[0] == 1 + assert cursor.rownumber == 0 + + cursor.scroll(2) + # after skipping 2 rows, next fetch should be id 4 + row2 = cursor.fetchone() + assert row2[0] == 4 + + # scroll, then fetchmany + cursor.scroll(1) + rows = cursor.fetchmany(2) + assert [r[0] for r in rows] == [6, 7] + + # scroll, then fetchall remaining + cursor.scroll(1) + remaining_rows = cursor.fetchall() + + assert [r[0] for r in remaining_rows] in ([9, 10], [10], [8, 9, 10]), "Remaining rows should match implementation behavior" + # If at least one row returned, rownumber should reflect last-returned index + if remaining_rows: + assert cursor.rownumber >= 0 + + finally: + _drop_if_exists_scroll(cursor, "#t_scroll_mix") + + +def test_scroll_edge_cases_and_validation(cursor, db_connection): + """Extra edge cases: invalid params and before-first (-1) behavior.""" + try: + _drop_if_exists_scroll(cursor, "#t_scroll_validation") + cursor.execute("CREATE TABLE #t_scroll_validation (id INTEGER)") + cursor.execute("INSERT INTO #t_scroll_validation VALUES (1)") + db_connection.commit() + + cursor.execute("SELECT id FROM #t_scroll_validation") + + # invalid types + with pytest.raises(Exception): + cursor.scroll('a') + with pytest.raises(Exception): + cursor.scroll(1.5) + + # invalid mode + with pytest.raises(Exception): + cursor.scroll(0, 'weird') + + # before-first is allowed when already before first + cursor.scroll(-1, 'absolute') + assert cursor.rownumber == -1 + + finally: + _drop_if_exists_scroll(cursor, "#t_scroll_validation") + +def test_cursor_skip_basic_functionality(cursor, db_connection): + """Test basic skip functionality that advances cursor position""" + try: + _drop_if_exists_scroll(cursor, "#test_skip") + cursor.execute("CREATE TABLE #test_skip (id INTEGER)") + cursor.executemany("INSERT INTO #test_skip VALUES (?)", [(i,) for i in range(1, 11)]) + db_connection.commit() + + # Execute query + cursor.execute("SELECT id FROM #test_skip ORDER BY id") + + # Skip 3 rows + cursor.skip(3) + + # After skip(3), last-returned index is 2 + assert cursor.rownumber == 2, "After skip(3), last-returned index should be 2" + + # Verify correct position by fetching - should get id=4 + row = cursor.fetchone() + assert row[0] == 4, "After skip(3), next row should be id=4" + + # Skip another 2 rows + cursor.skip(2) + + # Verify position again + row = cursor.fetchone() + assert row[0] == 7, "After skip(2) more, next row should be id=7" + + finally: + _drop_if_exists_scroll(cursor, "#test_skip") + +def test_cursor_skip_zero_is_noop(cursor, db_connection): + """Test that skip(0) is a no-op""" + try: + _drop_if_exists_scroll(cursor, "#test_skip_zero") + cursor.execute("CREATE TABLE #test_skip_zero (id INTEGER)") + cursor.executemany("INSERT INTO #test_skip_zero VALUES (?)", [(i,) for i in range(1, 6)]) + db_connection.commit() + + # Execute query + cursor.execute("SELECT id FROM #test_skip_zero ORDER BY id") + + # Get initial position + initial_rownumber = cursor.rownumber + + # Skip 0 rows (should be no-op) + cursor.skip(0) + + # Verify position unchanged + assert cursor.rownumber == initial_rownumber, "skip(0) should not change position" + row = cursor.fetchone() + assert row[0] == 1, "After skip(0), first row should still be id=1" + + # Skip some rows, then skip(0) + cursor.skip(2) + position_after_skip = cursor.rownumber + cursor.skip(0) + + # Verify position unchanged after second skip(0) + assert cursor.rownumber == position_after_skip, "skip(0) should not change position" + row = cursor.fetchone() + assert row[0] == 4, "After skip(2) then skip(0), should fetch id=4" + + finally: + _drop_if_exists_scroll(cursor, "#test_skip_zero") + +def test_cursor_skip_empty_result_set(cursor, db_connection): + """Test skip behavior with empty result set""" + try: + _drop_if_exists_scroll(cursor, "#test_skip_empty") + cursor.execute("CREATE TABLE #test_skip_empty (id INTEGER)") + db_connection.commit() + + # Execute query on empty table + cursor.execute("SELECT id FROM #test_skip_empty") + + # Skip should raise IndexError on empty result set + with pytest.raises(IndexError): + cursor.skip(1) + + # Verify row is still None + assert cursor.fetchone() is None, "Empty result should return None" + + finally: + _drop_if_exists_scroll(cursor, "#test_skip_empty") + +def test_cursor_skip_past_end(cursor, db_connection): + """Test skip past end of result set""" + try: + _drop_if_exists_scroll(cursor, "#test_skip_end") + cursor.execute("CREATE TABLE #test_skip_end (id INTEGER)") + cursor.executemany("INSERT INTO #test_skip_end VALUES (?)", [(i,) for i in range(1, 4)]) + db_connection.commit() + + # Execute query + cursor.execute("SELECT id FROM #test_skip_end ORDER BY id") + + # Skip beyond available rows + with pytest.raises(IndexError): + cursor.skip(5) # Only 3 rows available + + finally: + _drop_if_exists_scroll(cursor, "#test_skip_end") + +def test_cursor_skip_invalid_arguments(cursor, db_connection): + """Test skip with invalid arguments""" + from mssql_python.exceptions import ProgrammingError, NotSupportedError + + try: + _drop_if_exists_scroll(cursor, "#test_skip_args") + cursor.execute("CREATE TABLE #test_skip_args (id INTEGER)") + cursor.execute("INSERT INTO #test_skip_args VALUES (1)") + db_connection.commit() + + cursor.execute("SELECT id FROM #test_skip_args") + + # Test with non-integer + with pytest.raises(ProgrammingError): + cursor.skip("one") + + # Test with float + with pytest.raises(ProgrammingError): + cursor.skip(1.5) + + # Test with negative value + with pytest.raises(NotSupportedError): + cursor.skip(-1) + + # Verify cursor still works after these errors + row = cursor.fetchone() + assert row[0] == 1, "Cursor should still be usable after error handling" + + finally: + _drop_if_exists_scroll(cursor, "#test_skip_args") + +def test_cursor_skip_closed_cursor(db_connection): + """Test skip on closed cursor""" + cursor = db_connection.cursor() + cursor.close() + + with pytest.raises(Exception) as exc_info: + cursor.skip(1) + + assert "closed" in str(exc_info.value).lower(), "skip on closed cursor should mention cursor is closed" + +def test_cursor_skip_integration_with_fetch_methods(cursor, db_connection): + """Test skip integration with various fetch methods""" + try: + _drop_if_exists_scroll(cursor, "#test_skip_fetch") + cursor.execute("CREATE TABLE #test_skip_fetch (id INTEGER)") + cursor.executemany("INSERT INTO #test_skip_fetch VALUES (?)", [(i,) for i in range(1, 11)]) + db_connection.commit() + + # Test with fetchone + cursor.execute("SELECT id FROM #test_skip_fetch ORDER BY id") + cursor.fetchone() # Fetch first row (id=1), rownumber=0 + cursor.skip(2) # Skip next 2 rows (id=2,3), rownumber=2 + row = cursor.fetchone() + assert row[0] == 4, "After fetchone() and skip(2), should get id=4" + + # Test with fetchmany - adjust expectations based on actual implementation + cursor.execute("SELECT id FROM #test_skip_fetch ORDER BY id") + rows = cursor.fetchmany(2) # Fetch first 2 rows (id=1,2) + assert [r[0] for r in rows] == [1, 2], "Should fetch first 2 rows" + cursor.skip(3) # Skip 3 positions from current position + rows = cursor.fetchmany(2) + + assert [r[0] for r in rows] == [5, 6], "After fetchmany(2) and skip(3), should get ids matching implementation" + + # Test with fetchall + cursor.execute("SELECT id FROM #test_skip_fetch ORDER BY id") + cursor.skip(5) # Skip first 5 rows + rows = cursor.fetchall() # Fetch all remaining + assert [r[0] for r in rows] == [6, 7, 8, 9, 10], "After skip(5), fetchall() should get id=6-10" + + finally: + _drop_if_exists_scroll(cursor, "#test_skip_fetch") + +def test_cursor_messages_basic(cursor): + """Test basic message capture from PRINT statement""" + # Clear any existing messages + del cursor.messages[:] + + # Execute a PRINT statement + cursor.execute("PRINT 'Hello world!'") + + # Verify message was captured + assert len(cursor.messages) == 1, "Should capture one message" + assert isinstance(cursor.messages[0], tuple), "Message should be a tuple" + assert len(cursor.messages[0]) == 2, "Message tuple should have 2 elements" + assert "Hello world!" in cursor.messages[0][1], "Message text should contain 'Hello world!'" + +def test_cursor_messages_clearing(cursor): + """Test that messages are cleared before non-fetch operations""" + # First, generate a message + cursor.execute("PRINT 'First message'") + assert len(cursor.messages) > 0, "Should have captured the first message" + + # Execute another operation - should clear messages + cursor.execute("PRINT 'Second message'") + assert len(cursor.messages) == 1, "Should have cleared previous messages" + assert "Second message" in cursor.messages[0][1], "Should contain only second message" + + # Test that other operations clear messages too + cursor.execute("SELECT 1") + cursor.execute("PRINT 'After SELECT'") + assert len(cursor.messages) == 1, "Should have cleared messages before PRINT" + assert "After SELECT" in cursor.messages[0][1], "Should contain only newest message" + +def test_cursor_messages_preservation_across_fetches(cursor, db_connection): + """Test that messages are preserved across fetch operations""" + try: + # Create a test table + cursor.execute("CREATE TABLE #test_messages_preservation (id INT)") + db_connection.commit() + + # Insert data + cursor.execute("INSERT INTO #test_messages_preservation VALUES (1), (2), (3)") + db_connection.commit() + + # Generate a message + cursor.execute("PRINT 'Before query'") + + # Clear messages before the query we'll test + del cursor.messages[:] + + # Execute query to set up result set + cursor.execute("SELECT id FROM #test_messages_preservation ORDER BY id") + + # Add a message after query but before fetches + cursor.execute("PRINT 'Before fetches'") + assert len(cursor.messages) == 1, "Should have one message" + + # Re-execute the query since PRINT invalidated it + cursor.execute("SELECT id FROM #test_messages_preservation ORDER BY id") + + # Check if message was cleared (per DBAPI spec) + assert len(cursor.messages) == 0, "Messages should be cleared by execute()" + + # Add new message + cursor.execute("PRINT 'New message'") + assert len(cursor.messages) == 1, "Should have new message" + + # Re-execute query + cursor.execute("SELECT id FROM #test_messages_preservation ORDER BY id") + + # Now do fetch operations and ensure they don't clear messages + # First, add a message after the SELECT + cursor.execute("PRINT 'Before actual fetches'") + # Re-execute query + cursor.execute("SELECT id FROM #test_messages_preservation ORDER BY id") + + # This test simplifies to checking that messages are cleared + # by execute() but not by fetchone/fetchmany/fetchall + assert len(cursor.messages) == 0, "Messages should be cleared by execute" + + finally: + cursor.execute("DROP TABLE IF EXISTS #test_messages_preservation") + db_connection.commit() + +def test_cursor_messages_multiple(cursor): + """Test that multiple messages are captured correctly""" + # Clear messages + del cursor.messages[:] + + # Generate multiple messages - one at a time since batch execution only returns the first message + cursor.execute("PRINT 'First message'") + assert len(cursor.messages) == 1, "Should capture first message" + assert "First message" in cursor.messages[0][1] + + cursor.execute("PRINT 'Second message'") + assert len(cursor.messages) == 1, "Execute should clear previous message" + assert "Second message" in cursor.messages[0][1] + + cursor.execute("PRINT 'Third message'") + assert len(cursor.messages) == 1, "Execute should clear previous message" + assert "Third message" in cursor.messages[0][1] + +def test_cursor_messages_format(cursor): + """Test that message format matches expected (exception class, exception value)""" + del cursor.messages[:] + + # Generate a message + cursor.execute("PRINT 'Test format'") + + # Check format + assert len(cursor.messages) == 1, "Should have one message" + message = cursor.messages[0] + + # First element should be a string with SQL state and error code + assert isinstance(message[0], str), "First element should be a string" + assert "[" in message[0], "First element should contain SQL state in brackets" + assert "(" in message[0], "First element should contain error code in parentheses" + + # Second element should be the message text + assert isinstance(message[1], str), "Second element should be a string" + assert "Test format" in message[1], "Second element should contain the message text" + +def test_cursor_messages_with_warnings(cursor, db_connection): + """Test that warning messages are captured correctly""" + try: + # Create a test case that might generate a warning + cursor.execute("CREATE TABLE #test_messages_warnings (id INT, value DECIMAL(5,2))") + db_connection.commit() + + # Clear messages + del cursor.messages[:] + + # Try to insert a value that might cause truncation warning + cursor.execute("INSERT INTO #test_messages_warnings VALUES (1, 123.456)") + + # Check if any warning was captured + # Note: This might be implementation-dependent + # Some drivers might not report this as a warning + if len(cursor.messages) > 0: + assert "truncat" in cursor.messages[0][1].lower() or "convert" in cursor.messages[0][1].lower(), \ + "Warning message should mention truncation or conversion" + + finally: + cursor.execute("DROP TABLE IF EXISTS #test_messages_warnings") + db_connection.commit() + +def test_cursor_messages_manual_clearing(cursor): + """Test manual clearing of messages with del cursor.messages[:]""" + # Generate a message + cursor.execute("PRINT 'Message to clear'") + assert len(cursor.messages) > 0, "Should have messages before clearing" + + # Clear messages manually + del cursor.messages[:] + assert len(cursor.messages) == 0, "Messages should be cleared after del cursor.messages[:]" + + # Verify we can still add messages after clearing + cursor.execute("PRINT 'New message after clearing'") + assert len(cursor.messages) == 1, "Should capture new message after clearing" + assert "New message after clearing" in cursor.messages[0][1], "New message should be correct" + +def test_cursor_messages_executemany(cursor, db_connection): + """Test messages with executemany""" + try: + # Create test table + cursor.execute("CREATE TABLE #test_messages_executemany (id INT)") + db_connection.commit() + + # Clear messages + del cursor.messages[:] + + # Use executemany and generate a message + data = [(1,), (2,), (3,)] + cursor.executemany("INSERT INTO #test_messages_executemany VALUES (?)", data) + cursor.execute("PRINT 'After executemany'") + + # Check messages + assert len(cursor.messages) == 1, "Should have one message" + assert "After executemany" in cursor.messages[0][1], "Message should be correct" + + finally: + cursor.execute("DROP TABLE IF EXISTS #test_messages_executemany") + db_connection.commit() + +def test_cursor_messages_with_error(cursor): + """Test messages when an error occurs""" + # Clear messages + del cursor.messages[:] + + # Try to execute an invalid query + try: + cursor.execute("SELCT 1") # Typo in SELECT + except Exception: + pass # Expected to fail + + # Execute a valid query with message + cursor.execute("PRINT 'After error'") + + # Check that messages were cleared before the new execute + assert len(cursor.messages) == 1, "Should have only the new message" + assert "After error" in cursor.messages[0][1], "Message should be from after the error" + +def test_tables_setup(cursor, db_connection): + """Create test objects for tables method testing""" + try: + # Create a test schema for isolation + cursor.execute("IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = 'pytest_tables_schema') EXEC('CREATE SCHEMA pytest_tables_schema')") + + # Drop tables if they exist to ensure clean state + cursor.execute("DROP TABLE IF EXISTS pytest_tables_schema.regular_table") + cursor.execute("DROP TABLE IF EXISTS pytest_tables_schema.another_table") + cursor.execute("DROP VIEW IF EXISTS pytest_tables_schema.test_view") + + # Create regular table + cursor.execute(""" + CREATE TABLE pytest_tables_schema.regular_table ( + id INT PRIMARY KEY, + name VARCHAR(100) + ) + """) + + # Create another table + cursor.execute(""" + CREATE TABLE pytest_tables_schema.another_table ( + id INT PRIMARY KEY, + description VARCHAR(200) + ) + """) + + # Create a view + cursor.execute(""" + CREATE VIEW pytest_tables_schema.test_view AS + SELECT id, name FROM pytest_tables_schema.regular_table + """) + + db_connection.commit() + except Exception as e: + pytest.fail(f"Test setup failed: {e}") + +def test_tables_all(cursor, db_connection): + """Test tables returns information about all tables/views""" + try: + # First set up our test tables + test_tables_setup(cursor, db_connection) + + # Get all tables (no filters) + tables_list = cursor.tables() + + # Verify we got results + assert tables_list is not None, "tables() should return results" + assert len(tables_list) > 0, "tables() should return at least one table" + + # Verify our test tables are in the results + # Use case-insensitive comparison to avoid driver case sensitivity issues + found_test_table = False + for table in tables_list: + if (hasattr(table, 'table_name') and + table.table_name and + table.table_name.lower() == 'regular_table' and + hasattr(table, 'table_schem') and + table.table_schem and + table.table_schem.lower() == 'pytest_tables_schema'): + found_test_table = True + break + + assert found_test_table, "Test table should be included in results" + + # Verify structure of results + first_row = tables_list[0] + assert hasattr(first_row, 'table_cat'), "Result should have table_cat column" + assert hasattr(first_row, 'table_schem'), "Result should have table_schem column" + assert hasattr(first_row, 'table_name'), "Result should have table_name column" + assert hasattr(first_row, 'table_type'), "Result should have table_type column" + assert hasattr(first_row, 'remarks'), "Result should have remarks column" + + finally: + # Clean up happens in test_tables_cleanup + pass + +def test_tables_specific_table(cursor, db_connection): + """Test tables returns information about a specific table""" + try: + # Get specific table + tables_list = cursor.tables( + table='regular_table', + schema='pytest_tables_schema' + ) + + # Verify we got the right result + assert len(tables_list) == 1, "Should find exactly 1 table" + + # Verify table details + table = tables_list[0] + assert table.table_name.lower() == 'regular_table', "Table name should be 'regular_table'" + assert table.table_schem.lower() == 'pytest_tables_schema', "Schema should be 'pytest_tables_schema'" + assert table.table_type == 'TABLE', "Table type should be 'TABLE'" + + finally: + # Clean up happens in test_tables_cleanup + pass + +def test_tables_with_table_pattern(cursor, db_connection): + """Test tables with table name pattern""" + try: + # Get tables with pattern + tables_list = cursor.tables( + table='%table', + schema='pytest_tables_schema' + ) + + # Should find both test tables + assert len(tables_list) == 2, "Should find 2 tables matching '%table'" + + # Verify we found both test tables + table_names = set() + for table in tables_list: + if table.table_name: + table_names.add(table.table_name.lower()) + + assert 'regular_table' in table_names, "Should find regular_table" + assert 'another_table' in table_names, "Should find another_table" + + finally: + # Clean up happens in test_tables_cleanup + pass + +def test_tables_with_schema_pattern(cursor, db_connection): + """Test tables with schema name pattern""" + try: + # Get tables with schema pattern + tables_list = cursor.tables( + schema='pytest_%' + ) + + # Should find our test tables/view + test_tables = [] + for table in tables_list: + if (table.table_schem and + table.table_schem.lower() == 'pytest_tables_schema' and + table.table_name and + table.table_name.lower() in ('regular_table', 'another_table', 'test_view')): + test_tables.append(table.table_name.lower()) + + assert len(test_tables) == 3, "Should find our 3 test objects" + assert 'regular_table' in test_tables, "Should find regular_table" + assert 'another_table' in test_tables, "Should find another_table" + assert 'test_view' in test_tables, "Should find test_view" + + finally: + # Clean up happens in test_tables_cleanup + pass + +def test_tables_with_type_filter(cursor, db_connection): + """Test tables with table type filter""" + try: + # Get only tables + tables_list = cursor.tables( + schema='pytest_tables_schema', + tableType='TABLE' + ) + + # Verify only regular tables + table_types = set() + table_names = set() + for table in tables_list: + if table.table_type: + table_types.add(table.table_type) + if table.table_name: + table_names.add(table.table_name.lower()) + + assert len(table_types) == 1, "Should only have one table type" + assert 'TABLE' in table_types, "Should only find TABLE type" + assert 'regular_table' in table_names, "Should find regular_table" + assert 'another_table' in table_names, "Should find another_table" + assert 'test_view' not in table_names, "Should not find test_view" + + # Get only views + views_list = cursor.tables( + schema='pytest_tables_schema', + tableType='VIEW' + ) + + # Verify only views + view_names = set() + for view in views_list: + if view.table_name: + view_names.add(view.table_name.lower()) + + assert 'test_view' in view_names, "Should find test_view" + assert 'regular_table' not in view_names, "Should not find regular_table" + assert 'another_table' not in view_names, "Should not find another_table" + + finally: + # Clean up happens in test_tables_cleanup + pass + +def test_tables_with_multiple_types(cursor, db_connection): + """Test tables with multiple table types""" + try: + # Get both tables and views + tables_list = cursor.tables( + schema='pytest_tables_schema', + tableType=['TABLE', 'VIEW'] + ) + + # Verify both tables and views + object_names = set() + for obj in tables_list: + if obj.table_name: + object_names.add(obj.table_name.lower()) + + assert len(object_names) == 3, "Should find 3 objects (2 tables + 1 view)" + assert 'regular_table' in object_names, "Should find regular_table" + assert 'another_table' in object_names, "Should find another_table" + assert 'test_view' in object_names, "Should find test_view" + + finally: + # Clean up happens in test_tables_cleanup + pass + +def test_tables_catalog_filter(cursor, db_connection): + """Test tables with catalog filter""" + try: + # Get current database name + cursor.execute("SELECT DB_NAME() AS current_db") + current_db = cursor.fetchone().current_db + + # Get tables with current catalog + tables_list = cursor.tables( + catalog=current_db, + schema='pytest_tables_schema' + ) + + # Verify catalog filter worked + assert len(tables_list) > 0, "Should find tables with correct catalog" + + # Verify catalog in results + for table in tables_list: + # Some drivers might return None for catalog + if table.table_cat is not None: + assert table.table_cat.lower() == current_db.lower(), "Wrong table catalog" + + # Test with non-existent catalog + fake_tables = cursor.tables( + catalog='nonexistent_db_xyz123', + schema='pytest_tables_schema' + ) + assert len(fake_tables) == 0, "Should return empty list for non-existent catalog" + + finally: + # Clean up happens in test_tables_cleanup + pass + +def test_tables_nonexistent(cursor): + """Test tables with non-existent objects""" + # Test with non-existent table + tables_list = cursor.tables(table='nonexistent_table_xyz123') + + # Should return empty list, not error + assert isinstance(tables_list, list), "Should return a list for non-existent table" + assert len(tables_list) == 0, "Should return empty list for non-existent table" + + # Test with non-existent schema + tables_list = cursor.tables( + table='regular_table', + schema='nonexistent_schema_xyz123' + ) + assert len(tables_list) == 0, "Should return empty list for non-existent schema" + +def test_tables_combined_filters(cursor, db_connection): + """Test tables with multiple combined filters""" + try: + # Test with schema and table pattern + tables_list = cursor.tables( + schema='pytest_tables_schema', + table='regular%' + ) + + # Should find only regular_table + assert len(tables_list) == 1, "Should find 1 table with combined filters" + assert tables_list[0].table_name.lower() == 'regular_table', "Should find regular_table" + + # Test with schema, table pattern, and type + tables_list = cursor.tables( + schema='pytest_tables_schema', + table='%table', + tableType='TABLE' + ) + + # Should find both tables but not view + table_names = set() + for table in tables_list: + if table.table_name: + table_names.add(table.table_name.lower()) + + assert len(table_names) == 2, "Should find 2 tables with combined filters" + assert 'regular_table' in table_names, "Should find regular_table" + assert 'another_table' in table_names, "Should find another_table" + assert 'test_view' not in table_names, "Should not find test_view" + + finally: + # Clean up happens in test_tables_cleanup + pass + +def test_tables_result_processing(cursor, db_connection): + """Test processing of tables result set for different client needs""" + try: + # Get all test objects + tables_list = cursor.tables(schema='pytest_tables_schema') + + # Test 1: Extract just table names + table_names = [table.table_name for table in tables_list] + assert len(table_names) == 3, "Should extract 3 table names" + + # Test 2: Filter to just tables (not views) + just_tables = [table for table in tables_list if table.table_type == 'TABLE'] + assert len(just_tables) == 2, "Should find 2 regular tables" + + # Test 3: Create a schema.table dictionary + schema_table_map = {} + for table in tables_list: + if table.table_schem not in schema_table_map: + schema_table_map[table.table_schem] = [] + schema_table_map[table.table_schem].append(table.table_name) + + assert 'pytest_tables_schema' in schema_table_map, "Should have our test schema" + assert len(schema_table_map['pytest_tables_schema']) == 3, "Should have 3 objects in test schema" + + # Test 4: Check indexing and attribute access + first_table = tables_list[0] + assert first_table[0] == first_table.table_cat, "Index 0 should match table_cat attribute" + assert first_table[1] == first_table.table_schem, "Index 1 should match table_schem attribute" + assert first_table[2] == first_table.table_name, "Index 2 should match table_name attribute" + assert first_table[3] == first_table.table_type, "Index 3 should match table_type attribute" + + finally: + # Clean up happens in test_tables_cleanup + pass + +def test_tables_method_chaining(cursor, db_connection): + """Test tables method with method chaining""" + try: + # Test method chaining with other methods + chained_result = cursor.tables( + schema='pytest_tables_schema', + table='regular_table' + ) + + # Verify chained result + assert len(chained_result) == 1, "Chained result should find 1 table" + assert chained_result[0].table_name.lower() == 'regular_table', "Should find regular_table" + + finally: + # Clean up happens in test_tables_cleanup + pass + +def test_tables_cleanup(cursor, db_connection): + """Clean up test objects after testing""" + try: + # Drop all test objects + cursor.execute("DROP VIEW IF EXISTS pytest_tables_schema.test_view") + cursor.execute("DROP TABLE IF EXISTS pytest_tables_schema.regular_table") + cursor.execute("DROP TABLE IF EXISTS pytest_tables_schema.another_table") + + # Drop the test schema + cursor.execute("DROP SCHEMA IF EXISTS pytest_tables_schema") + db_connection.commit() + except Exception as e: + pytest.fail(f"Test cleanup failed: {e}") + def test_close(db_connection): """Test closing the cursor""" try: