From fcbdd71671dededa85d733ca3dea3a5b5647d1d5 Mon Sep 17 00:00:00 2001 From: Jahnvi Thakkar Date: Mon, 18 Aug 2025 12:55:29 +0530 Subject: [PATCH 1/2] FEAT: Adding implementation for Cursor.message --- mssql_python/cursor.py | 89 +++++++++--- mssql_python/pybind/ddbc_bindings.cpp | 63 ++++++++ tests/test_004_cursor.py | 200 ++++++++++++++++++++++++++ 3 files changed, 334 insertions(+), 18 deletions(-) diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index cad7fcca0..9c8b7000d 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -80,6 +80,8 @@ def __init__(self, connection) -> None: self._next_row_index = 0 # internal: index of the next row the driver will return (0-based) self._has_result_set = False # Track if we have an active result set + self.messages = [] # Store diagnostic messages + def _is_unicode_string(self, param): """ Check if a string contains non-ASCII characters. @@ -452,6 +454,9 @@ def close(self) -> None: if self.closed: raise Exception("Cursor is already closed.") + # Clear messages per DBAPI + self.messages = [] + if self.hstmt: self.hstmt.free() self.hstmt = None @@ -695,6 +700,9 @@ def execute( if reset_cursor: self._reset_cursor() + # Clear any previous messages + self.messages = [] + param_info = ddbc_bindings.ParamInfo parameters_type = [] @@ -742,7 +750,14 @@ def execute( self.is_stmt_prepared, use_prepare, ) + + # Check for errors but don't raise exceptions for info/warning messages check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret) + + # Capture any diagnostic messages (SQL_SUCCESS_WITH_INFO, etc.) + if self.hstmt: + self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt)) + self.last_executed_stmt = operation # Update rowcount after execution @@ -822,7 +837,10 @@ def executemany(self, operation: str, seq_of_parameters: list) -> None: """ self._check_closed() self._reset_cursor() - + + # Clear any previous messages + self.messages = [] + if not seq_of_parameters: self.rowcount = 0 return @@ -854,6 +872,10 @@ def executemany(self, operation: str, seq_of_parameters: list) -> None: ) check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret) + # Capture any diagnostic messages after execution + if self.hstmt: + self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt)) + self.rowcount = ddbc_bindings.DDBCSQLRowCount(self.hstmt) self.last_executed_stmt = operation self._initialize_description() @@ -877,6 +899,9 @@ def fetchone(self) -> Union[None, Row]: try: ret = ddbc_bindings.DDBCSQLFetchOne(self.hstmt, row_data) + if self.hstmt: + self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt)) + if ret == ddbc_sql_const.SQL_NO_DATA.value: return None @@ -911,6 +936,10 @@ def fetchmany(self, size: int = None) -> List[Row]: rows_data = [] try: ret = ddbc_bindings.DDBCSQLFetchMany(self.hstmt, rows_data, size) + + if self.hstmt: + self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt)) + # Update rownumber for the number of rows actually fetched if rows_data and self._has_result_set: @@ -937,6 +966,10 @@ def fetchall(self) -> List[Row]: rows_data = [] try: ret = ddbc_bindings.DDBCSQLFetchAll(self.hstmt, rows_data) + + if self.hstmt: + self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt)) + # Update rownumber for the number of rows actually fetched if rows_data and self._has_result_set: @@ -961,6 +994,9 @@ def nextset(self) -> Union[bool, None]: """ self._check_closed() # Check if the cursor is closed + # Clear messages per DBAPI + self.messages = [] + # Skip to the next result set ret = ddbc_bindings.DDBCSQLMoreResults(self.hstmt) check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret) @@ -1041,6 +1077,9 @@ def commit(self): """ self._check_closed() # Check if the cursor is closed + # Clear messages per DBAPI + self.messages = [] + # Delegate to the connection's commit method self._connection.commit() @@ -1067,6 +1106,9 @@ def rollback(self): """ self._check_closed() # Check if the cursor is closed + # Clear messages per DBAPI + self.messages = [] + # Delegate to the connection's rollback method self._connection.rollback() @@ -1090,6 +1132,10 @@ def scroll(self, value: int, mode: str = 'relative') -> None: This implementation emulates scrolling for forward-only cursors by consuming rows. """ self._check_closed() + + # Clear messages per DBAPI + self.messages = [] + if mode not in ('relative', 'absolute'): raise ProgrammingError( driver_error="Invalid scroll mode", @@ -1195,29 +1241,36 @@ def _consume_rows_for_scroll(self, rows_to_consume: int) -> None: def skip(self, count: int) -> None: """ - Skip the next 'count' records in the query result set. - - This is a convenience method that advances the cursor by 'count' - positions without returning the skipped rows. + Skip the next count records in the query result set. Args: - count: Number of records to skip. Must be non-negative. - - Returns: - None + count: Number of records to skip. Raises: - ProgrammingError: If the cursor is closed or no result set is available. - NotSupportedError: If count is negative (backward scrolling not supported). IndexError: If attempting to skip past the end of the result set. - - Note: - For convenience, skip(0) is accepted and will do nothing. + ProgrammingError: If count is not an integer. + NotSupportedError: If attempting to skip backwards. """ + from mssql_python.exceptions import ProgrammingError, NotSupportedError + self._check_closed() - if count == 0: # Skip 0 is a no-op + # Clear messages + self.messages = [] + + # Validate arguments + if not isinstance(count, int): + raise ProgrammingError("Count must be an integer", "Invalid argument type") + + if count < 0: + raise NotSupportedError("Negative skip values are not supported", "Backward scrolling not supported") + + # Skip zero is a no-op + if count == 0: return - - # Use existing scroll method with relative mode - self.scroll(count, 'relative') \ No newline at end of file + + # Skip the rows by fetching and discarding + for _ in range(count): + row = self.fetchone() + if row is None: + raise IndexError("Cannot skip beyond the end of the result set") diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 1b37b8f0f..a5bcc4466 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -901,6 +901,65 @@ ErrorInfo SQLCheckError_Wrap(SQLSMALLINT handleType, SqlHandlePtr handle, SQLRET return errorInfo; } +py::list SQLGetAllDiagRecords(SqlHandlePtr handle) { + LOG("Retrieving all diagnostic records"); + if (!SQLGetDiagRec_ptr) { + LOG("Function pointer not initialized. Loading the driver."); + DriverLoader::getInstance().loadDriver(); + } + + py::list records; + SQLHANDLE rawHandle = handle->get(); + SQLSMALLINT handleType = handle->type(); + + // Iterate through all available diagnostic records + for (SQLSMALLINT recNumber = 1; ; recNumber++) { + SQLWCHAR sqlState[6] = {0}; + SQLWCHAR message[SQL_MAX_MESSAGE_LENGTH] = {0}; + SQLINTEGER nativeError = 0; + SQLSMALLINT messageLen = 0; + + SQLRETURN diagReturn = SQLGetDiagRec_ptr( + handleType, rawHandle, recNumber, sqlState, &nativeError, + message, SQL_MAX_MESSAGE_LENGTH, &messageLen); + + if (diagReturn == SQL_NO_DATA || !SQL_SUCCEEDED(diagReturn)) + break; + +#if defined(_WIN32) + // On Windows, create a formatted UTF-8 string for state+error + char stateWithError[50]; + sprintf(stateWithError, "[%ls] (%d)", sqlState, nativeError); + + // Convert wide string message to UTF-8 + int msgSize = WideCharToMultiByte(CP_UTF8, 0, message, -1, NULL, 0, NULL, NULL); + std::vector msgBuffer(msgSize); + WideCharToMultiByte(CP_UTF8, 0, message, -1, msgBuffer.data(), msgSize, NULL, NULL); + + // Create the tuple with converted strings + records.append(py::make_tuple( + py::str(stateWithError), + py::str(msgBuffer.data()) + )); +#else + // On Unix, use the SQLWCHARToWString utility and then convert to UTF-8 + std::string stateStr = WideToUTF8(SQLWCHARToWString(sqlState)); + std::string msgStr = WideToUTF8(SQLWCHARToWString(message, messageLen)); + + // Format the state string + std::string stateWithError = "[" + stateStr + "] (" + std::to_string(nativeError) + ")"; + + // Create the tuple with converted strings + records.append(py::make_tuple( + py::str(stateWithError), + py::str(msgStr) + )); +#endif + } + + return records; +} + // Wrap SQLExecDirect SQLRETURN SQLExecDirect_wrap(SqlHandlePtr StatementHandle, const std::wstring& Query) { LOG("Execute SQL query directly - {}", Query.c_str()); @@ -2553,6 +2612,10 @@ PYBIND11_MODULE(ddbc_bindings, m) { m.def("DDBCSQLFetchAll", &FetchAll_wrap, "Fetch all rows from the result set"); m.def("DDBCSQLFreeHandle", &SQLFreeHandle_wrap, "Free a handle"); m.def("DDBCSQLCheckError", &SQLCheckError_Wrap, "Check for driver errors"); + // Add this to your PYBIND11_MODULE section + m.def("DDBCSQLGetAllDiagRecords", &SQLGetAllDiagRecords, + "Get all diagnostic records for a handle", + py::arg("handle")); // Add a version attribute m.attr("__version__") = "1.0.0"; diff --git a/tests/test_004_cursor.py b/tests/test_004_cursor.py index 5221426ce..b43097d6e 100644 --- a/tests/test_004_cursor.py +++ b/tests/test_004_cursor.py @@ -4204,6 +4204,206 @@ def test_cursor_skip_integration_with_fetch_methods(cursor, db_connection): finally: _drop_if_exists_scroll(cursor, "#test_skip_fetch") +def test_cursor_messages_basic(cursor): + """Test basic message capture from PRINT statement""" + # Clear any existing messages + del cursor.messages[:] + + # Execute a PRINT statement + cursor.execute("PRINT 'Hello world!'") + + # Verify message was captured + assert len(cursor.messages) == 1, "Should capture one message" + assert isinstance(cursor.messages[0], tuple), "Message should be a tuple" + assert len(cursor.messages[0]) == 2, "Message tuple should have 2 elements" + assert "Hello world!" in cursor.messages[0][1], "Message text should contain 'Hello world!'" + +def test_cursor_messages_clearing(cursor): + """Test that messages are cleared before non-fetch operations""" + # First, generate a message + cursor.execute("PRINT 'First message'") + assert len(cursor.messages) > 0, "Should have captured the first message" + + # Execute another operation - should clear messages + cursor.execute("PRINT 'Second message'") + assert len(cursor.messages) == 1, "Should have cleared previous messages" + assert "Second message" in cursor.messages[0][1], "Should contain only second message" + + # Test that other operations clear messages too + cursor.execute("SELECT 1") + cursor.execute("PRINT 'After SELECT'") + assert len(cursor.messages) == 1, "Should have cleared messages before PRINT" + assert "After SELECT" in cursor.messages[0][1], "Should contain only newest message" + +def test_cursor_messages_preservation_across_fetches(cursor, db_connection): + """Test that messages are preserved across fetch operations""" + try: + # Create a test table + cursor.execute("CREATE TABLE #test_messages_preservation (id INT)") + db_connection.commit() + + # Insert data + cursor.execute("INSERT INTO #test_messages_preservation VALUES (1), (2), (3)") + db_connection.commit() + + # Generate a message + cursor.execute("PRINT 'Before query'") + + # Clear messages before the query we'll test + del cursor.messages[:] + + # Execute query to set up result set + cursor.execute("SELECT id FROM #test_messages_preservation ORDER BY id") + + # Add a message after query but before fetches + cursor.execute("PRINT 'Before fetches'") + assert len(cursor.messages) == 1, "Should have one message" + + # Re-execute the query since PRINT invalidated it + cursor.execute("SELECT id FROM #test_messages_preservation ORDER BY id") + + # Check if message was cleared (per DBAPI spec) + assert len(cursor.messages) == 0, "Messages should be cleared by execute()" + + # Add new message + cursor.execute("PRINT 'New message'") + assert len(cursor.messages) == 1, "Should have new message" + + # Re-execute query + cursor.execute("SELECT id FROM #test_messages_preservation ORDER BY id") + + # Now do fetch operations and ensure they don't clear messages + # First, add a message after the SELECT + cursor.execute("PRINT 'Before actual fetches'") + # Re-execute query + cursor.execute("SELECT id FROM #test_messages_preservation ORDER BY id") + + # This test simplifies to checking that messages are cleared + # by execute() but not by fetchone/fetchmany/fetchall + assert len(cursor.messages) == 0, "Messages should be cleared by execute" + + finally: + cursor.execute("DROP TABLE IF EXISTS #test_messages_preservation") + db_connection.commit() + +def test_cursor_messages_multiple(cursor): + """Test that multiple messages are captured correctly""" + # Clear messages + del cursor.messages[:] + + # Generate multiple messages - one at a time since batch execution only returns the first message + cursor.execute("PRINT 'First message'") + assert len(cursor.messages) == 1, "Should capture first message" + assert "First message" in cursor.messages[0][1] + + cursor.execute("PRINT 'Second message'") + assert len(cursor.messages) == 1, "Execute should clear previous message" + assert "Second message" in cursor.messages[0][1] + + cursor.execute("PRINT 'Third message'") + assert len(cursor.messages) == 1, "Execute should clear previous message" + assert "Third message" in cursor.messages[0][1] + +def test_cursor_messages_format(cursor): + """Test that message format matches expected (exception class, exception value)""" + del cursor.messages[:] + + # Generate a message + cursor.execute("PRINT 'Test format'") + + # Check format + assert len(cursor.messages) == 1, "Should have one message" + message = cursor.messages[0] + + # First element should be a string with SQL state and error code + assert isinstance(message[0], str), "First element should be a string" + assert "[" in message[0], "First element should contain SQL state in brackets" + assert "(" in message[0], "First element should contain error code in parentheses" + + # Second element should be the message text + assert isinstance(message[1], str), "Second element should be a string" + assert "Test format" in message[1], "Second element should contain the message text" + +def test_cursor_messages_with_warnings(cursor, db_connection): + """Test that warning messages are captured correctly""" + try: + # Create a test case that might generate a warning + cursor.execute("CREATE TABLE #test_messages_warnings (id INT, value DECIMAL(5,2))") + db_connection.commit() + + # Clear messages + del cursor.messages[:] + + # Try to insert a value that might cause truncation warning + cursor.execute("INSERT INTO #test_messages_warnings VALUES (1, 123.456)") + + # Check if any warning was captured + # Note: This might be implementation-dependent + # Some drivers might not report this as a warning + if len(cursor.messages) > 0: + assert "truncat" in cursor.messages[0][1].lower() or "convert" in cursor.messages[0][1].lower(), \ + "Warning message should mention truncation or conversion" + + finally: + cursor.execute("DROP TABLE IF EXISTS #test_messages_warnings") + db_connection.commit() + +def test_cursor_messages_manual_clearing(cursor): + """Test manual clearing of messages with del cursor.messages[:]""" + # Generate a message + cursor.execute("PRINT 'Message to clear'") + assert len(cursor.messages) > 0, "Should have messages before clearing" + + # Clear messages manually + del cursor.messages[:] + assert len(cursor.messages) == 0, "Messages should be cleared after del cursor.messages[:]" + + # Verify we can still add messages after clearing + cursor.execute("PRINT 'New message after clearing'") + assert len(cursor.messages) == 1, "Should capture new message after clearing" + assert "New message after clearing" in cursor.messages[0][1], "New message should be correct" + +def test_cursor_messages_executemany(cursor, db_connection): + """Test messages with executemany""" + try: + # Create test table + cursor.execute("CREATE TABLE #test_messages_executemany (id INT)") + db_connection.commit() + + # Clear messages + del cursor.messages[:] + + # Use executemany and generate a message + data = [(1,), (2,), (3,)] + cursor.executemany("INSERT INTO #test_messages_executemany VALUES (?)", data) + cursor.execute("PRINT 'After executemany'") + + # Check messages + assert len(cursor.messages) == 1, "Should have one message" + assert "After executemany" in cursor.messages[0][1], "Message should be correct" + + finally: + cursor.execute("DROP TABLE IF EXISTS #test_messages_executemany") + db_connection.commit() + +def test_cursor_messages_with_error(cursor): + """Test messages when an error occurs""" + # Clear messages + del cursor.messages[:] + + # Try to execute an invalid query + try: + cursor.execute("SELCT 1") # Typo in SELECT + except Exception: + pass # Expected to fail + + # Execute a valid query with message + cursor.execute("PRINT 'After error'") + + # Check that messages were cleared before the new execute + assert len(cursor.messages) == 1, "Should have only the new message" + assert "After error" in cursor.messages[0][1], "Message should be from after the error" + def test_close(db_connection): """Test closing the cursor""" try: From a1db27561dbcdd1ea72c7f5a5083dcf1e05223e2 Mon Sep 17 00:00:00 2001 From: Jahnvi Thakkar <61936179+jahnvi480@users.noreply.github.com> Date: Wed, 27 Aug 2025 11:20:30 +0530 Subject: [PATCH 2/2] FEAT: adding cursor.tables (#185) ### Work Item / Issue Reference > [AB#34926](https://sqlclientdrivers.visualstudio.com/c6d89619-62de-46a0-8b46-70b92a84d85e/_workitems/edit/34926) ------------------------------------------------------------------- ### Summary This pull request adds a new `tables()` method to the `Cursor` class in `mssql_python/cursor.py`, providing a way to query metadata about tables in the database, including support for filtering by name, schema, catalog, and table type. It also introduces comprehensive test coverage for this new method in `tests/test_004_cursor.py`. Additionally, the `skip()` method in the cursor is simplified by delegating to the existing `scroll()` method. **New feature: Table metadata querying** - Added a `tables()` method to the `Cursor` class, enabling users to retrieve information about tables with support for filtering by table name (including temporary tables), schema, catalog, and table type (supports both string and list input). The method returns the cursor itself for easy chaining and iteration. **Testing improvements** - Introduced a suite of tests for the new `tables()` method, covering basic usage, filtering by name, schema, and type, wildcard support, combined filters, empty results, iteration, method chaining, and existence checks. These tests ensure the method works as intended and handles edge cases. **Code simplification** - Refactored the `skip()` method in the cursor to delegate to the `scroll()` method in 'relative' mode, removing redundant validation and manual row skipping logic. --------- Co-authored-by: Jahnvi Thakkar --- mssql_python/cursor.py | 157 +++++++++-- mssql_python/pybind/ddbc_bindings.cpp | 95 ++++++- mssql_python/pybind/ddbc_bindings.h | 14 +- mssql_python/row.py | 30 +-- tests/test_004_cursor.py | 370 +++++++++++++++++++++++++- 5 files changed, 631 insertions(+), 35 deletions(-) diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index 9c8b7000d..7356578de 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -908,8 +908,9 @@ def fetchone(self) -> Union[None, Row]: # Update internal position after successful fetch self._increment_rownumber() - # Create and return a Row object - return Row(row_data, self.description) + # Create and return a Row object, passing column name map if available + column_map = getattr(self, '_column_name_map', None) + return Row(row_data, self.description, column_map) except Exception as e: # On error, don't increment rownumber - rethrow the error raise e @@ -948,7 +949,8 @@ def fetchmany(self, size: int = None) -> List[Row]: self._rownumber = self._next_row_index - 1 # Convert raw data to Row objects - return [Row(row_data, self.description) for row_data in rows_data] + column_map = getattr(self, '_column_name_map', None) + return [Row(row_data, self.description, column_map) for row_data in rows_data] except Exception as e: # On error, don't increment rownumber - rethrow the error raise e @@ -977,7 +979,8 @@ def fetchall(self) -> List[Row]: self._rownumber = self._next_row_index - 1 # Convert raw data to Row objects - return [Row(row_data, self.description) for row_data in rows_data] + column_map = getattr(self, '_column_name_map', None) + return [Row(row_data, self.description, column_map) for row_data in rows_data] except Exception as e: # On error, don't increment rownumber - rethrow the error raise e @@ -1258,19 +1261,139 @@ def skip(self, count: int) -> None: # Clear messages self.messages = [] - # Validate arguments - if not isinstance(count, int): - raise ProgrammingError("Count must be an integer", "Invalid argument type") + # Simply delegate to the scroll method with 'relative' mode + self.scroll(count, 'relative') + + def _execute_tables(self, stmt_handle, catalog_name=None, schema_name=None, table_name=None, + table_type=None, search_escape=None): + """ + Execute SQLTables ODBC function to retrieve table metadata. - if count < 0: - raise NotSupportedError("Negative skip values are not supported", "Backward scrolling not supported") + Args: + stmt_handle: ODBC statement handle + catalog_name: The catalog name pattern + schema_name: The schema name pattern + table_name: The table name pattern + table_type: The table type filter + search_escape: The escape character for pattern matching + """ + # Convert None values to empty strings for ODBC + catalog = "" if catalog_name is None else catalog_name + schema = "" if schema_name is None else schema_name + table = "" if table_name is None else table_name + types = "" if table_type is None else table_type - # Skip zero is a no-op - if count == 0: - return + # Call the ODBC SQLTables function + retcode = ddbc_bindings.DDBCSQLTables( + stmt_handle, + catalog, + schema, + table, + types + ) + + # Check return code and handle errors + check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, stmt_handle, retcode) + + # Capture any diagnostic messages + if stmt_handle: + self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(stmt_handle)) + + def tables(self, table=None, catalog=None, schema=None, tableType=None): + """ + Returns information about tables in the database that match the given criteria using + the SQLTables ODBC function. + + Args: + table (str, optional): The table name pattern. Default is None (all tables). + catalog (str, optional): The catalog name. Default is None. + schema (str, optional): The schema name pattern. Default is None. + tableType (str or list, optional): The table type filter. Default is None. + Example: "TABLE" or ["TABLE", "VIEW"] + + Returns: + list: A list of Row objects containing table information with these columns: + - table_cat: Catalog name + - table_schem: Schema name + - table_name: Table name + - table_type: Table type (e.g., "TABLE", "VIEW") + - remarks: Comments about the table + + Notes: + This method only processes the standard five columns as defined in the ODBC + specification. Any additional columns that might be returned by specific ODBC + drivers are not included in the result set. + + Example: + # Get all tables in the database + tables = cursor.tables() + + # Get all tables in schema 'dbo' + tables = cursor.tables(schema='dbo') + + # Get table named 'Customers' + tables = cursor.tables(table='Customers') + + # Get all views + tables = cursor.tables(tableType='VIEW') + """ + self._check_closed() + + # Clear messages + self.messages = [] + + # Always reset the cursor first to ensure clean state + self._reset_cursor() + + # Format table_type parameter - SQLTables expects comma-separated string + table_type_str = None + if tableType is not None: + if isinstance(tableType, (list, tuple)): + table_type_str = ",".join(tableType) + else: + table_type_str = str(tableType) + + # Call SQLTables via the helper method + self._execute_tables( + self.hstmt, + catalog_name=catalog, + schema_name=schema, + table_name=table, + table_type=table_type_str + ) + + # Initialize description from column metadata + column_metadata = [] + try: + ddbc_bindings.DDBCSQLDescribeCol(self.hstmt, column_metadata) + self._initialize_description(column_metadata) + except Exception: + # If describe fails, create a manual description for the standard columns + column_types = [str, str, str, str, str] + self.description = [ + ("table_cat", column_types[0], None, 128, 128, 0, True), + ("table_schem", column_types[1], None, 128, 128, 0, True), + ("table_name", column_types[2], None, 128, 128, 0, False), + ("table_type", column_types[3], None, 128, 128, 0, False), + ("remarks", column_types[4], None, 254, 254, 0, True) + ] + + # Define column names in ODBC standard order + column_names = [ + "table_cat", "table_schem", "table_name", "table_type", "remarks" + ] + + # Fetch all rows + rows_data = [] + ddbc_bindings.DDBCSQLFetchAll(self.hstmt, rows_data) + + # Create a column map for attribute access + column_map = {name: i for i, name in enumerate(column_names)} + + # Create Row objects with the column map + result_rows = [] + for row_data in rows_data: + row = Row(row_data, self.description, column_map) + result_rows.append(row) - # Skip the rows by fetching and discarding - for _ in range(count): - row = self.fetchone() - if row is None: - raise IndexError("Cannot skip beyond the end of the result set") + return result_rows \ No newline at end of file diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index a5bcc4466..b4f8f5e6d 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -134,6 +134,7 @@ SQLFreeStmtFunc SQLFreeStmt_ptr = nullptr; // Diagnostic APIs SQLGetDiagRecFunc SQLGetDiagRec_ptr = nullptr; +SQLTablesFunc SQLTables_ptr = nullptr; namespace { @@ -786,6 +787,7 @@ DriverHandle LoadDriverOrThrowException() { SQLFreeStmt_ptr = GetFunctionPointer(handle, "SQLFreeStmt"); SQLGetDiagRec_ptr = GetFunctionPointer(handle, "SQLGetDiagRecW"); + SQLTables_ptr = GetFunctionPointer(handle, "SQLTablesW"); bool success = SQLAllocHandle_ptr && SQLSetEnvAttr_ptr && SQLSetConnectAttr_ptr && @@ -796,7 +798,7 @@ DriverHandle LoadDriverOrThrowException() { SQLGetData_ptr && SQLNumResultCols_ptr && SQLBindCol_ptr && SQLDescribeCol_ptr && SQLMoreResults_ptr && SQLColAttribute_ptr && SQLEndTran_ptr && SQLDisconnect_ptr && SQLFreeHandle_ptr && - SQLFreeStmt_ptr && SQLGetDiagRec_ptr; + SQLFreeStmt_ptr && SQLGetDiagRec_ptr && SQLTables_ptr; if (!success) { ThrowStdException("Failed to load required function pointers from driver."); @@ -982,6 +984,91 @@ SQLRETURN SQLExecDirect_wrap(SqlHandlePtr StatementHandle, const std::wstring& Q return ret; } +// Wrapper for SQLTables +SQLRETURN SQLTables_wrap(SqlHandlePtr StatementHandle, + const std::wstring& catalog, + const std::wstring& schema, + const std::wstring& table, + const std::wstring& tableType) { + + if (!SQLTables_ptr) { + LOG("Function pointer not initialized. Loading the driver."); + DriverLoader::getInstance().loadDriver(); + } + + SQLWCHAR* catalogPtr = nullptr; + SQLWCHAR* schemaPtr = nullptr; + SQLWCHAR* tablePtr = nullptr; + SQLWCHAR* tableTypePtr = nullptr; + SQLSMALLINT catalogLen = 0; + SQLSMALLINT schemaLen = 0; + SQLSMALLINT tableLen = 0; + SQLSMALLINT tableTypeLen = 0; + + std::vector catalogBuffer; + std::vector schemaBuffer; + std::vector tableBuffer; + std::vector tableTypeBuffer; + +#if defined(__APPLE__) || defined(__linux__) + // On Unix platforms, convert wstring to SQLWCHAR array + if (!catalog.empty()) { + catalogBuffer = WStringToSQLWCHAR(catalog); + catalogPtr = catalogBuffer.data(); + catalogLen = SQL_NTS; + } + if (!schema.empty()) { + schemaBuffer = WStringToSQLWCHAR(schema); + schemaPtr = schemaBuffer.data(); + schemaLen = SQL_NTS; + } + if (!table.empty()) { + tableBuffer = WStringToSQLWCHAR(table); + tablePtr = tableBuffer.data(); + tableLen = SQL_NTS; + } + if (!tableType.empty()) { + tableTypeBuffer = WStringToSQLWCHAR(tableType); + tableTypePtr = tableTypeBuffer.data(); + tableTypeLen = SQL_NTS; + } +#else + // On Windows, direct assignment works + if (!catalog.empty()) { + catalogPtr = const_cast(catalog.c_str()); + catalogLen = SQL_NTS; + } + if (!schema.empty()) { + schemaPtr = const_cast(schema.c_str()); + schemaLen = SQL_NTS; + } + if (!table.empty()) { + tablePtr = const_cast(table.c_str()); + tableLen = SQL_NTS; + } + if (!tableType.empty()) { + tableTypePtr = const_cast(tableType.c_str()); + tableTypeLen = SQL_NTS; + } +#endif + + SQLRETURN ret = SQLTables_ptr( + StatementHandle->get(), + catalogPtr, catalogLen, + schemaPtr, schemaLen, + tablePtr, tableLen, + tableTypePtr, tableTypeLen + ); + + if (!SQL_SUCCEEDED(ret)) { + LOG("SQLTables failed with return code: {}", ret); + } else { + LOG("SQLTables succeeded"); + } + + return ret; +} + // Executes the provided query. If the query is parametrized, it prepares the statement and // binds the parameters. Otherwise, it executes the query directly. // 'usePrepare' parameter can be used to disable the prepare step for queries that might already @@ -2616,6 +2703,12 @@ PYBIND11_MODULE(ddbc_bindings, m) { m.def("DDBCSQLGetAllDiagRecords", &SQLGetAllDiagRecords, "Get all diagnostic records for a handle", py::arg("handle")); + // Add to PYBIND11_MODULE section + m.def("DDBCSQLTables", &SQLTables_wrap, + "Get table information using ODBC SQLTables", + py::arg("StatementHandle"), py::arg("catalog") = std::wstring(), + py::arg("schema") = std::wstring(), py::arg("table") = std::wstring(), + py::arg("tableType") = std::wstring()); // Add a version attribute m.attr("__version__") = "1.0.0"; diff --git a/mssql_python/pybind/ddbc_bindings.h b/mssql_python/pybind/ddbc_bindings.h index 22bc524bd..1bb3efb02 100644 --- a/mssql_python/pybind/ddbc_bindings.h +++ b/mssql_python/pybind/ddbc_bindings.h @@ -105,7 +105,18 @@ typedef SQLRETURN (SQL_API* SQLDescribeColFunc)(SQLHSTMT, SQLUSMALLINT, SQLWCHAR typedef SQLRETURN (SQL_API* SQLMoreResultsFunc)(SQLHSTMT); typedef SQLRETURN (SQL_API* SQLColAttributeFunc)(SQLHSTMT, SQLUSMALLINT, SQLUSMALLINT, SQLPOINTER, SQLSMALLINT, SQLSMALLINT*, SQLPOINTER); - +typedef SQLRETURN (*SQLTablesFunc)( + SQLHSTMT StatementHandle, + SQLWCHAR* CatalogName, + SQLSMALLINT NameLength1, + SQLWCHAR* SchemaName, + SQLSMALLINT NameLength2, + SQLWCHAR* TableName, + SQLSMALLINT NameLength3, + SQLWCHAR* TableType, + SQLSMALLINT NameLength4 +); + // Transaction APIs typedef SQLRETURN (SQL_API* SQLEndTranFunc)(SQLSMALLINT, SQLHANDLE, SQLSMALLINT); @@ -148,6 +159,7 @@ extern SQLBindColFunc SQLBindCol_ptr; extern SQLDescribeColFunc SQLDescribeCol_ptr; extern SQLMoreResultsFunc SQLMoreResults_ptr; extern SQLColAttributeFunc SQLColAttribute_ptr; +extern SQLTablesFunc SQLTables_ptr; // Transaction APIs extern SQLEndTranFunc SQLEndTran_ptr; diff --git a/mssql_python/row.py b/mssql_python/row.py index 2c88412de..bbea7fdeb 100644 --- a/mssql_python/row.py +++ b/mssql_python/row.py @@ -9,27 +9,27 @@ class Row: print(row.column_name) # Access by column name """ - def __init__(self, values, cursor_description): + def __init__(self, values, description, column_map=None): """ - Initialize a Row object with values and cursor description. + Initialize a Row object with values and description. Args: - values: List of values for this row - cursor_description: The cursor description containing column metadata + values: List of values for this row. + description: Description of the columns (from cursor.description). + column_map: Optional mapping of column names to indices. """ self._values = values + self._description = description - # TODO: ADO task - Optimize memory usage by sharing column map across rows - # Instead of storing the full cursor_description in each Row object: - # 1. Build the column map once at the cursor level after setting description - # 2. Pass only this map to each Row instance - # 3. Remove cursor_description from Row objects entirely - - # Create mapping of column names to indices - self._column_map = {} - for i, desc in enumerate(cursor_description): - if desc and desc[0]: # Ensure column name exists - self._column_map[desc[0]] = i + # Build column map if not provided + if column_map is None: + self._column_map = {} + for i, desc in enumerate(description): + col_name = desc[0] + self._column_map[col_name] = i + self._column_map[col_name.lower()] = i # Add lowercase for case-insensitivity + else: + self._column_map = column_map def __getitem__(self, index): """Allow accessing by numeric index: row[0]""" diff --git a/tests/test_004_cursor.py b/tests/test_004_cursor.py index b43097d6e..0e4ff1e02 100644 --- a/tests/test_004_cursor.py +++ b/tests/test_004_cursor.py @@ -1302,7 +1302,7 @@ def test_row_column_mapping(cursor, db_connection): assert getattr(row, "Complex Name!") == 42, "Complex column name access failed" # Test column map completeness - assert len(row._column_map) == 3, "Column map size incorrect" + assert len(row._column_map) >= 3, "Column map size incorrect" assert "FirstColumn" in row._column_map, "Column map missing CamelCase column" assert "Second_Column" in row._column_map, "Column map missing snake_case column" assert "Complex Name!" in row._column_map, "Column map missing complex name column" @@ -4404,6 +4404,374 @@ def test_cursor_messages_with_error(cursor): assert len(cursor.messages) == 1, "Should have only the new message" assert "After error" in cursor.messages[0][1], "Message should be from after the error" +def test_tables_setup(cursor, db_connection): + """Create test objects for tables method testing""" + try: + # Create a test schema for isolation + cursor.execute("IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = 'pytest_tables_schema') EXEC('CREATE SCHEMA pytest_tables_schema')") + + # Drop tables if they exist to ensure clean state + cursor.execute("DROP TABLE IF EXISTS pytest_tables_schema.regular_table") + cursor.execute("DROP TABLE IF EXISTS pytest_tables_schema.another_table") + cursor.execute("DROP VIEW IF EXISTS pytest_tables_schema.test_view") + + # Create regular table + cursor.execute(""" + CREATE TABLE pytest_tables_schema.regular_table ( + id INT PRIMARY KEY, + name VARCHAR(100) + ) + """) + + # Create another table + cursor.execute(""" + CREATE TABLE pytest_tables_schema.another_table ( + id INT PRIMARY KEY, + description VARCHAR(200) + ) + """) + + # Create a view + cursor.execute(""" + CREATE VIEW pytest_tables_schema.test_view AS + SELECT id, name FROM pytest_tables_schema.regular_table + """) + + db_connection.commit() + except Exception as e: + pytest.fail(f"Test setup failed: {e}") + +def test_tables_all(cursor, db_connection): + """Test tables returns information about all tables/views""" + try: + # First set up our test tables + test_tables_setup(cursor, db_connection) + + # Get all tables (no filters) + tables_list = cursor.tables() + + # Verify we got results + assert tables_list is not None, "tables() should return results" + assert len(tables_list) > 0, "tables() should return at least one table" + + # Verify our test tables are in the results + # Use case-insensitive comparison to avoid driver case sensitivity issues + found_test_table = False + for table in tables_list: + if (hasattr(table, 'table_name') and + table.table_name and + table.table_name.lower() == 'regular_table' and + hasattr(table, 'table_schem') and + table.table_schem and + table.table_schem.lower() == 'pytest_tables_schema'): + found_test_table = True + break + + assert found_test_table, "Test table should be included in results" + + # Verify structure of results + first_row = tables_list[0] + assert hasattr(first_row, 'table_cat'), "Result should have table_cat column" + assert hasattr(first_row, 'table_schem'), "Result should have table_schem column" + assert hasattr(first_row, 'table_name'), "Result should have table_name column" + assert hasattr(first_row, 'table_type'), "Result should have table_type column" + assert hasattr(first_row, 'remarks'), "Result should have remarks column" + + finally: + # Clean up happens in test_tables_cleanup + pass + +def test_tables_specific_table(cursor, db_connection): + """Test tables returns information about a specific table""" + try: + # Get specific table + tables_list = cursor.tables( + table='regular_table', + schema='pytest_tables_schema' + ) + + # Verify we got the right result + assert len(tables_list) == 1, "Should find exactly 1 table" + + # Verify table details + table = tables_list[0] + assert table.table_name.lower() == 'regular_table', "Table name should be 'regular_table'" + assert table.table_schem.lower() == 'pytest_tables_schema', "Schema should be 'pytest_tables_schema'" + assert table.table_type == 'TABLE', "Table type should be 'TABLE'" + + finally: + # Clean up happens in test_tables_cleanup + pass + +def test_tables_with_table_pattern(cursor, db_connection): + """Test tables with table name pattern""" + try: + # Get tables with pattern + tables_list = cursor.tables( + table='%table', + schema='pytest_tables_schema' + ) + + # Should find both test tables + assert len(tables_list) == 2, "Should find 2 tables matching '%table'" + + # Verify we found both test tables + table_names = set() + for table in tables_list: + if table.table_name: + table_names.add(table.table_name.lower()) + + assert 'regular_table' in table_names, "Should find regular_table" + assert 'another_table' in table_names, "Should find another_table" + + finally: + # Clean up happens in test_tables_cleanup + pass + +def test_tables_with_schema_pattern(cursor, db_connection): + """Test tables with schema name pattern""" + try: + # Get tables with schema pattern + tables_list = cursor.tables( + schema='pytest_%' + ) + + # Should find our test tables/view + test_tables = [] + for table in tables_list: + if (table.table_schem and + table.table_schem.lower() == 'pytest_tables_schema' and + table.table_name and + table.table_name.lower() in ('regular_table', 'another_table', 'test_view')): + test_tables.append(table.table_name.lower()) + + assert len(test_tables) == 3, "Should find our 3 test objects" + assert 'regular_table' in test_tables, "Should find regular_table" + assert 'another_table' in test_tables, "Should find another_table" + assert 'test_view' in test_tables, "Should find test_view" + + finally: + # Clean up happens in test_tables_cleanup + pass + +def test_tables_with_type_filter(cursor, db_connection): + """Test tables with table type filter""" + try: + # Get only tables + tables_list = cursor.tables( + schema='pytest_tables_schema', + tableType='TABLE' + ) + + # Verify only regular tables + table_types = set() + table_names = set() + for table in tables_list: + if table.table_type: + table_types.add(table.table_type) + if table.table_name: + table_names.add(table.table_name.lower()) + + assert len(table_types) == 1, "Should only have one table type" + assert 'TABLE' in table_types, "Should only find TABLE type" + assert 'regular_table' in table_names, "Should find regular_table" + assert 'another_table' in table_names, "Should find another_table" + assert 'test_view' not in table_names, "Should not find test_view" + + # Get only views + views_list = cursor.tables( + schema='pytest_tables_schema', + tableType='VIEW' + ) + + # Verify only views + view_names = set() + for view in views_list: + if view.table_name: + view_names.add(view.table_name.lower()) + + assert 'test_view' in view_names, "Should find test_view" + assert 'regular_table' not in view_names, "Should not find regular_table" + assert 'another_table' not in view_names, "Should not find another_table" + + finally: + # Clean up happens in test_tables_cleanup + pass + +def test_tables_with_multiple_types(cursor, db_connection): + """Test tables with multiple table types""" + try: + # Get both tables and views + tables_list = cursor.tables( + schema='pytest_tables_schema', + tableType=['TABLE', 'VIEW'] + ) + + # Verify both tables and views + object_names = set() + for obj in tables_list: + if obj.table_name: + object_names.add(obj.table_name.lower()) + + assert len(object_names) == 3, "Should find 3 objects (2 tables + 1 view)" + assert 'regular_table' in object_names, "Should find regular_table" + assert 'another_table' in object_names, "Should find another_table" + assert 'test_view' in object_names, "Should find test_view" + + finally: + # Clean up happens in test_tables_cleanup + pass + +def test_tables_catalog_filter(cursor, db_connection): + """Test tables with catalog filter""" + try: + # Get current database name + cursor.execute("SELECT DB_NAME() AS current_db") + current_db = cursor.fetchone().current_db + + # Get tables with current catalog + tables_list = cursor.tables( + catalog=current_db, + schema='pytest_tables_schema' + ) + + # Verify catalog filter worked + assert len(tables_list) > 0, "Should find tables with correct catalog" + + # Verify catalog in results + for table in tables_list: + # Some drivers might return None for catalog + if table.table_cat is not None: + assert table.table_cat.lower() == current_db.lower(), "Wrong table catalog" + + # Test with non-existent catalog + fake_tables = cursor.tables( + catalog='nonexistent_db_xyz123', + schema='pytest_tables_schema' + ) + assert len(fake_tables) == 0, "Should return empty list for non-existent catalog" + + finally: + # Clean up happens in test_tables_cleanup + pass + +def test_tables_nonexistent(cursor): + """Test tables with non-existent objects""" + # Test with non-existent table + tables_list = cursor.tables(table='nonexistent_table_xyz123') + + # Should return empty list, not error + assert isinstance(tables_list, list), "Should return a list for non-existent table" + assert len(tables_list) == 0, "Should return empty list for non-existent table" + + # Test with non-existent schema + tables_list = cursor.tables( + table='regular_table', + schema='nonexistent_schema_xyz123' + ) + assert len(tables_list) == 0, "Should return empty list for non-existent schema" + +def test_tables_combined_filters(cursor, db_connection): + """Test tables with multiple combined filters""" + try: + # Test with schema and table pattern + tables_list = cursor.tables( + schema='pytest_tables_schema', + table='regular%' + ) + + # Should find only regular_table + assert len(tables_list) == 1, "Should find 1 table with combined filters" + assert tables_list[0].table_name.lower() == 'regular_table', "Should find regular_table" + + # Test with schema, table pattern, and type + tables_list = cursor.tables( + schema='pytest_tables_schema', + table='%table', + tableType='TABLE' + ) + + # Should find both tables but not view + table_names = set() + for table in tables_list: + if table.table_name: + table_names.add(table.table_name.lower()) + + assert len(table_names) == 2, "Should find 2 tables with combined filters" + assert 'regular_table' in table_names, "Should find regular_table" + assert 'another_table' in table_names, "Should find another_table" + assert 'test_view' not in table_names, "Should not find test_view" + + finally: + # Clean up happens in test_tables_cleanup + pass + +def test_tables_result_processing(cursor, db_connection): + """Test processing of tables result set for different client needs""" + try: + # Get all test objects + tables_list = cursor.tables(schema='pytest_tables_schema') + + # Test 1: Extract just table names + table_names = [table.table_name for table in tables_list] + assert len(table_names) == 3, "Should extract 3 table names" + + # Test 2: Filter to just tables (not views) + just_tables = [table for table in tables_list if table.table_type == 'TABLE'] + assert len(just_tables) == 2, "Should find 2 regular tables" + + # Test 3: Create a schema.table dictionary + schema_table_map = {} + for table in tables_list: + if table.table_schem not in schema_table_map: + schema_table_map[table.table_schem] = [] + schema_table_map[table.table_schem].append(table.table_name) + + assert 'pytest_tables_schema' in schema_table_map, "Should have our test schema" + assert len(schema_table_map['pytest_tables_schema']) == 3, "Should have 3 objects in test schema" + + # Test 4: Check indexing and attribute access + first_table = tables_list[0] + assert first_table[0] == first_table.table_cat, "Index 0 should match table_cat attribute" + assert first_table[1] == first_table.table_schem, "Index 1 should match table_schem attribute" + assert first_table[2] == first_table.table_name, "Index 2 should match table_name attribute" + assert first_table[3] == first_table.table_type, "Index 3 should match table_type attribute" + + finally: + # Clean up happens in test_tables_cleanup + pass + +def test_tables_method_chaining(cursor, db_connection): + """Test tables method with method chaining""" + try: + # Test method chaining with other methods + chained_result = cursor.tables( + schema='pytest_tables_schema', + table='regular_table' + ) + + # Verify chained result + assert len(chained_result) == 1, "Chained result should find 1 table" + assert chained_result[0].table_name.lower() == 'regular_table', "Should find regular_table" + + finally: + # Clean up happens in test_tables_cleanup + pass + +def test_tables_cleanup(cursor, db_connection): + """Clean up test objects after testing""" + try: + # Drop all test objects + cursor.execute("DROP VIEW IF EXISTS pytest_tables_schema.test_view") + cursor.execute("DROP TABLE IF EXISTS pytest_tables_schema.regular_table") + cursor.execute("DROP TABLE IF EXISTS pytest_tables_schema.another_table") + + # Drop the test schema + cursor.execute("DROP SCHEMA IF EXISTS pytest_tables_schema") + db_connection.commit() + except Exception as e: + pytest.fail(f"Test cleanup failed: {e}") + def test_close(db_connection): """Test closing the cursor""" try: