diff --git a/mssql_python/connection.py b/mssql_python/connection.py index 59eb33066..5aa2efb3d 100644 --- a/mssql_python/connection.py +++ b/mssql_python/connection.py @@ -66,6 +66,7 @@ def _validate_encoding(encoding: str) -> bool: ProgrammingError, NotSupportedError, ) +from mssql_python.constants import GetInfoConstants class Connection: @@ -544,6 +545,30 @@ def getdecoding(self, sqltype): return self._decoding_settings[sqltype].copy() + @property + def searchescape(self): + """ + The ODBC search pattern escape character, as returned by + SQLGetInfo(SQL_SEARCH_PATTERN_ESCAPE), used to escape special characters + such as '%' and '_' in LIKE clauses. These are driver specific. + + Returns: + str: The search pattern escape character (usually '\' or another character) + """ + if not hasattr(self, '_searchescape'): + try: + escape_char = self.getinfo(GetInfoConstants.SQL_SEARCH_PATTERN_ESCAPE.value) + # Some drivers might return this as an integer memory address + # or other non-string format, so ensure we have a string + if not isinstance(escape_char, str): + escape_char = '\\' # Default to backslash if not a string + self._searchescape = escape_char + except Exception as e: + # Log the exception for debugging, but do not expose sensitive info + log('warning', f"Failed to retrieve search escape character, using default '\\'. Exception: {type(e).__name__}") + self._searchescape = '\\' + return self._searchescape + def cursor(self) -> Cursor: """ Return a new Cursor object using the connection. @@ -824,6 +849,30 @@ def batch_execute(self, statements, params=None, reuse_cursor=None, auto_close=F log('debug', "Automatically closed cursor after batch execution") return results, cursor + + def getinfo(self, info_type): + """ + Return general information about the driver and data source. + + Args: + info_type (int): The type of information to return. See the ODBC + SQLGetInfo documentation for the supported values. + + Returns: + The requested information. The type of the returned value depends + on the information requested. It will be a string, integer, or boolean. + + Raises: + DatabaseError: If there is an error retrieving the information. + InterfaceError: If the connection is closed. + """ + if self._closed: + raise InterfaceError( + driver_error="Cannot get info on closed connection", + ddbc_error="Cannot get info on closed connection", + ) + + return self._conn.get_info(info_type) def commit(self) -> None: """ diff --git a/mssql_python/constants.py b/mssql_python/constants.py index 61380e1f3..2e56112e5 100644 --- a/mssql_python/constants.py +++ b/mssql_python/constants.py @@ -134,6 +134,144 @@ class ConstantsDDBC(Enum): SQL_QUICK = 0 SQL_ENSURE = 1 +class GetInfoConstants(Enum): + """ + These constants are used with various methods like getinfo(). + """ + + # Driver and database information + SQL_DRIVER_NAME = 6 + SQL_DRIVER_VER = 7 + SQL_DRIVER_ODBC_VER = 77 + SQL_DRIVER_HLIB = 76 + SQL_DRIVER_HENV = 75 + SQL_DRIVER_HDBC = 74 + SQL_DATA_SOURCE_NAME = 2 + SQL_DATABASE_NAME = 16 + SQL_SERVER_NAME = 13 + SQL_USER_NAME = 47 + + # SQL conformance and support + SQL_SQL_CONFORMANCE = 118 + SQL_KEYWORDS = 89 + SQL_IDENTIFIER_CASE = 28 + SQL_IDENTIFIER_QUOTE_CHAR = 29 + SQL_SPECIAL_CHARACTERS = 94 + SQL_SQL92_ENTRY_SQL = 127 + SQL_SQL92_INTERMEDIATE_SQL = 128 + SQL_SQL92_FULL_SQL = 129 + SQL_SUBQUERIES = 95 + SQL_EXPRESSIONS_IN_ORDERBY = 27 + SQL_CORRELATION_NAME = 74 + SQL_SEARCH_PATTERN_ESCAPE = 14 + + # Catalog and schema support + SQL_CATALOG_TERM = 42 + SQL_CATALOG_NAME_SEPARATOR = 41 + SQL_SCHEMA_TERM = 39 + SQL_TABLE_TERM = 45 + SQL_PROCEDURES = 21 + SQL_ACCESSIBLE_TABLES = 19 + SQL_ACCESSIBLE_PROCEDURES = 20 + SQL_CATALOG_NAME = 10002 + SQL_CATALOG_USAGE = 92 + SQL_SCHEMA_USAGE = 91 + SQL_COLUMN_ALIAS = 87 + SQL_DESCRIBE_PARAMETER = 10003 + + # Transaction support + SQL_TXN_CAPABLE = 46 + SQL_TXN_ISOLATION_OPTION = 72 + SQL_DEFAULT_TXN_ISOLATION = 26 + SQL_MULTIPLE_ACTIVE_TXN = 37 + SQL_TXN_ISOLATION_LEVEL = 108 + + # Data type support + SQL_NUMERIC_FUNCTIONS = 49 + SQL_STRING_FUNCTIONS = 50 + SQL_DATETIME_FUNCTIONS = 51 + SQL_SYSTEM_FUNCTIONS = 58 + SQL_CONVERT_FUNCTIONS = 48 + SQL_LIKE_ESCAPE_CLAUSE = 113 + + # Numeric limits + SQL_MAX_COLUMN_NAME_LEN = 30 + SQL_MAX_TABLE_NAME_LEN = 35 + SQL_MAX_SCHEMA_NAME_LEN = 32 + SQL_MAX_CATALOG_NAME_LEN = 34 + SQL_MAX_IDENTIFIER_LEN = 10005 + SQL_MAX_STATEMENT_LEN = 105 + SQL_MAX_CHAR_LITERAL_LEN = 108 + SQL_MAX_BINARY_LITERAL_LEN = 112 + SQL_MAX_COLUMNS_IN_TABLE = 101 + SQL_MAX_COLUMNS_IN_SELECT = 100 + SQL_MAX_COLUMNS_IN_GROUP_BY = 97 + SQL_MAX_COLUMNS_IN_ORDER_BY = 99 + SQL_MAX_COLUMNS_IN_INDEX = 98 + SQL_MAX_TABLES_IN_SELECT = 106 + SQL_MAX_CONCURRENT_ACTIVITIES = 1 + SQL_MAX_DRIVER_CONNECTIONS = 0 + SQL_MAX_ROW_SIZE = 104 + SQL_MAX_USER_NAME_LEN = 107 + + # Connection attributes + SQL_ACTIVE_CONNECTIONS = 0 + SQL_ACTIVE_STATEMENTS = 1 + SQL_DATA_SOURCE_READ_ONLY = 25 + SQL_NEED_LONG_DATA_LEN = 111 + SQL_GETDATA_EXTENSIONS = 81 + + # Result set and cursor attributes + SQL_CURSOR_COMMIT_BEHAVIOR = 23 + SQL_CURSOR_ROLLBACK_BEHAVIOR = 24 + SQL_CURSOR_SENSITIVITY = 10001 + SQL_BOOKMARK_PERSISTENCE = 82 + SQL_DYNAMIC_CURSOR_ATTRIBUTES1 = 144 + SQL_DYNAMIC_CURSOR_ATTRIBUTES2 = 145 + SQL_FORWARD_ONLY_CURSOR_ATTRIBUTES1 = 146 + SQL_FORWARD_ONLY_CURSOR_ATTRIBUTES2 = 147 + SQL_STATIC_CURSOR_ATTRIBUTES1 = 150 + SQL_STATIC_CURSOR_ATTRIBUTES2 = 151 + SQL_KEYSET_CURSOR_ATTRIBUTES1 = 148 + SQL_KEYSET_CURSOR_ATTRIBUTES2 = 149 + SQL_SCROLL_OPTIONS = 44 + SQL_SCROLL_CONCURRENCY = 43 + SQL_FETCH_DIRECTION = 8 + SQL_ROWSET_SIZE = 9 + SQL_CONCURRENCY = 7 + SQL_ROW_NUMBER = 14 + SQL_STATIC_SENSITIVITY = 83 + SQL_BATCH_SUPPORT = 121 + SQL_BATCH_ROW_COUNT = 120 + SQL_PARAM_ARRAY_ROW_COUNTS = 153 + SQL_PARAM_ARRAY_SELECTS = 154 + + # Positioned statement support + SQL_POSITIONED_STATEMENTS = 80 + + # Other constants + SQL_GROUP_BY = 88 + SQL_OJ_CAPABILITIES = 65 + SQL_ORDER_BY_COLUMNS_IN_SELECT = 90 + SQL_OUTER_JOINS = 38 + SQL_QUOTED_IDENTIFIER_CASE = 93 + SQL_CONCAT_NULL_BEHAVIOR = 22 + SQL_NULL_COLLATION = 85 + SQL_ALTER_TABLE = 86 + SQL_UNION = 96 + SQL_DDL_INDEX = 170 + SQL_MULT_RESULT_SETS = 36 + SQL_OWNER_USAGE = 91 + SQL_QUALIFIER_USAGE = 92 + SQL_TIMEDATE_ADD_INTERVALS = 109 + SQL_TIMEDATE_DIFF_INTERVALS = 110 + + # Return values for some getinfo functions + SQL_IC_UPPER = 1 + SQL_IC_LOWER = 2 + SQL_IC_SENSITIVE = 3 + SQL_IC_MIXED = 4 + class AuthType(Enum): """Constants for authentication types""" INTERACTIVE = "activedirectoryinteractive" diff --git a/mssql_python/pybind/connection/connection.cpp b/mssql_python/pybind/connection/connection.cpp index 9782efd22..d28ee7c50 100644 --- a/mssql_python/pybind/connection/connection.cpp +++ b/mssql_python/pybind/connection/connection.cpp @@ -10,6 +10,7 @@ #include #define SQL_COPT_SS_ACCESS_TOKEN 1256 // Custom attribute ID for access token +#define SQL_MAX_SMALL_INT 32767 // Maximum value for SQLSMALLINT static SqlHandlePtr getEnvHandle() { static SqlHandlePtr envHandle = []() -> SqlHandlePtr { @@ -314,4 +315,195 @@ SqlHandlePtr ConnectionHandle::allocStatementHandle() { ThrowStdException("Connection object is not initialized"); } return _conn->allocStatementHandle(); +} + +py::object Connection::getInfo(SQLUSMALLINT infoType) const { + if (!_dbcHandle) { + ThrowStdException("Connection handle not allocated"); + } + + LOG("Getting connection info for type {}", infoType); + + // Use a vector for dynamic sizing + std::vector buffer(1024, 0); + SQLSMALLINT actualLength = 0; + SQLRETURN ret; + + // First try to get the info - handle SQLSMALLINT size limit + SQLSMALLINT bufferSize = (buffer.size() <= SQL_MAX_SMALL_INT) + ? static_cast(buffer.size()) + : SQL_MAX_SMALL_INT; + + ret = SQLGetInfo_ptr(_dbcHandle->get(), infoType, buffer.data(), bufferSize, &actualLength); + + // If truncation occurred (actualLength >= bufferSize means truncation) + if (SQL_SUCCEEDED(ret) && actualLength >= bufferSize) { + // Resize buffer to the needed size (add 1 for null terminator) + buffer.resize(actualLength + 1, 0); + + // Call again with the larger buffer - handle SQLSMALLINT size limit again + bufferSize = (buffer.size() <= SQL_MAX_SMALL_INT) + ? static_cast(buffer.size()) + : SQL_MAX_SMALL_INT; + + ret = SQLGetInfo_ptr(_dbcHandle->get(), infoType, buffer.data(), bufferSize, &actualLength); + } + + // Check for errors + if (!SQL_SUCCEEDED(ret)) { + checkError(ret); + } + + // Note: This implementation assumes the ODBC driver handles any necessary + // endianness conversions between the database server and the client. + + // Determine return type based on the InfoType + // String types usually have InfoType > 10000 or are specifically known string values + if (infoType > 10000 || + infoType == SQL_DATA_SOURCE_NAME || + infoType == SQL_DBMS_NAME || + infoType == SQL_DBMS_VER || + infoType == SQL_DRIVER_NAME || + infoType == SQL_DRIVER_VER || + // Add missing string types + infoType == SQL_IDENTIFIER_QUOTE_CHAR || + infoType == SQL_CATALOG_NAME_SEPARATOR || + infoType == SQL_CATALOG_TERM || + infoType == SQL_SCHEMA_TERM || + infoType == SQL_TABLE_TERM || + infoType == SQL_KEYWORDS || + infoType == SQL_PROCEDURE_TERM) { + // Return as string + return py::str(buffer.data()); + } + else if (infoType == SQL_DRIVER_ODBC_VER || + infoType == SQL_SERVER_NAME) { + // Return as string + return py::str(buffer.data()); + } + else { + // For numeric types, safely extract values + + // Ensure buffer has enough data for the expected type + switch (infoType) { + // 16-bit unsigned integers + case SQL_MAX_CONCURRENT_ACTIVITIES: + case SQL_MAX_DRIVER_CONNECTIONS: + case SQL_ODBC_API_CONFORMANCE: + case SQL_ODBC_SQL_CONFORMANCE: + case SQL_TXN_CAPABLE: + case SQL_MULTIPLE_ACTIVE_TXN: + case SQL_MAX_COLUMN_NAME_LEN: + case SQL_MAX_TABLE_NAME_LEN: + case SQL_PROCEDURES: + { + if (actualLength >= sizeof(SQLUSMALLINT) && buffer.size() >= sizeof(SQLUSMALLINT)) { + SQLUSMALLINT value = 0; + // Safely copy data by using std::copy instead of memcpy + std::copy(buffer.begin(), buffer.begin() + sizeof(SQLUSMALLINT), + reinterpret_cast(&value)); + return py::int_(value); + } + break; + } + + // 32-bit unsigned integers + case SQL_ASYNC_MODE: + case SQL_GETDATA_EXTENSIONS: + case SQL_MAX_ASYNC_CONCURRENT_STATEMENTS: + case SQL_MAX_COLUMNS_IN_GROUP_BY: + case SQL_MAX_COLUMNS_IN_ORDER_BY: + case SQL_MAX_COLUMNS_IN_SELECT: + case SQL_MAX_COLUMNS_IN_TABLE: + case SQL_MAX_ROW_SIZE: + case SQL_MAX_TABLES_IN_SELECT: + case SQL_MAX_USER_NAME_LEN: + case SQL_NUMERIC_FUNCTIONS: + case SQL_STRING_FUNCTIONS: + case SQL_SYSTEM_FUNCTIONS: + case SQL_TIMEDATE_FUNCTIONS: + case SQL_DEFAULT_TXN_ISOLATION: + case SQL_MAX_STATEMENT_LEN: + { + if (actualLength >= sizeof(SQLUINTEGER) && buffer.size() >= sizeof(SQLUINTEGER)) { + SQLUINTEGER value = 0; + // Safely copy data by using std::copy instead of memcpy + std::copy(buffer.begin(), buffer.begin() + sizeof(SQLUINTEGER), + reinterpret_cast(&value)); + return py::int_(value); + } + break; + } + + // Boolean flags (32-bit mask) + case SQL_AGGREGATE_FUNCTIONS: + case SQL_ALTER_TABLE: + case SQL_CATALOG_USAGE: + case SQL_DATETIME_LITERALS: + case SQL_INDEX_KEYWORDS: + case SQL_INSERT_STATEMENT: + case SQL_SCHEMA_USAGE: + case SQL_SQL_CONFORMANCE: + case SQL_SQL92_DATETIME_FUNCTIONS: + case SQL_SQL92_NUMERIC_VALUE_FUNCTIONS: + case SQL_SQL92_PREDICATES: + case SQL_SQL92_RELATIONAL_JOIN_OPERATORS: + case SQL_SQL92_STRING_FUNCTIONS: + case SQL_STATIC_CURSOR_ATTRIBUTES1: + case SQL_STATIC_CURSOR_ATTRIBUTES2: + { + if (actualLength >= sizeof(SQLUINTEGER) && buffer.size() >= sizeof(SQLUINTEGER)) { + SQLUINTEGER value = 0; + // Safely copy data by using std::copy instead of memcpy + std::copy(buffer.begin(), buffer.begin() + sizeof(SQLUINTEGER), + reinterpret_cast(&value)); + return py::int_(value); + } + break; + } + + // Handle any other types as integers, if enough data + default: + if (actualLength >= sizeof(SQLUINTEGER) && buffer.size() >= sizeof(SQLUINTEGER)) { + SQLUINTEGER value = 0; + // Safely copy data by using std::copy instead of memcpy + std::copy(buffer.begin(), buffer.begin() + sizeof(SQLUINTEGER), + reinterpret_cast(&value)); + return py::int_(value); + } + else if (actualLength >= sizeof(SQLUSMALLINT) && buffer.size() >= sizeof(SQLUSMALLINT)) { + SQLUSMALLINT value = 0; + // Safely copy data by using std::copy instead of memcpy + std::copy(buffer.begin(), buffer.begin() + sizeof(SQLUSMALLINT), + reinterpret_cast(&value)); + return py::int_(value); + } + // For very small integers (like bytes/chars) + else if (actualLength > 0 && buffer.size() >= sizeof(unsigned char)) { + // Try to interpret as a small integer + unsigned char value = 0; + // Safely copy data by using std::copy instead of memcpy + std::copy(buffer.begin(), buffer.begin() + sizeof(unsigned char), + reinterpret_cast(&value)); + return py::int_(value); + } + break; + } + } + + // If we get here and actualLength > 0, try to return as string as a last resort + if (actualLength > 0) { + return py::str(buffer.data()); + } + + // Default return in case nothing matched or buffer is too small + LOG("Unable to convert result for info type {}", infoType); + return py::none(); +} + +py::object ConnectionHandle::getInfo(SQLUSMALLINT infoType) const { + if (!_conn) { + ThrowStdException("Connection object is not initialized"); + } + return _conn->getInfo(infoType); } \ No newline at end of file diff --git a/mssql_python/pybind/connection/connection.h b/mssql_python/pybind/connection/connection.h index 6129125e1..66dd58952 100644 --- a/mssql_python/pybind/connection/connection.h +++ b/mssql_python/pybind/connection/connection.h @@ -42,6 +42,9 @@ class Connection { // Allocate a new statement handle on this connection. SqlHandlePtr allocStatementHandle(); + // Get information about the driver and data source + py::object getInfo(SQLUSMALLINT infoType) const; + private: void allocateDbcHandle(); void checkError(SQLRETURN ret) const; @@ -67,6 +70,9 @@ class ConnectionHandle { bool getAutocommit() const; SqlHandlePtr allocStatementHandle(); + // Get information about the driver and data source + py::object getInfo(SQLUSMALLINT infoType) const; + private: std::shared_ptr _conn; bool _usePool; diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index c7f7aefb6..fe8197fdc 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -131,6 +131,7 @@ SQLPrimaryKeysFunc SQLPrimaryKeys_ptr = nullptr; SQLSpecialColumnsFunc SQLSpecialColumns_ptr = nullptr; SQLStatisticsFunc SQLStatistics_ptr = nullptr; SQLColumnsFunc SQLColumns_ptr = nullptr; +SQLGetInfoFunc SQLGetInfo_ptr = nullptr; // Transaction APIs SQLEndTranFunc SQLEndTran_ptr = nullptr; @@ -827,6 +828,7 @@ DriverHandle LoadDriverOrThrowException() { SQLSpecialColumns_ptr = GetFunctionPointer(handle, "SQLSpecialColumnsW"); SQLStatistics_ptr = GetFunctionPointer(handle, "SQLStatisticsW"); SQLColumns_ptr = GetFunctionPointer(handle, "SQLColumnsW"); + SQLGetInfo_ptr = GetFunctionPointer(handle, "SQLGetInfoW"); SQLEndTran_ptr = GetFunctionPointer(handle, "SQLEndTran"); SQLDisconnect_ptr = GetFunctionPointer(handle, "SQLDisconnect"); @@ -850,7 +852,7 @@ DriverHandle LoadDriverOrThrowException() { SQLGetData_ptr && SQLNumResultCols_ptr && SQLBindCol_ptr && SQLDescribeCol_ptr && SQLMoreResults_ptr && SQLColAttribute_ptr && SQLEndTran_ptr && SQLDisconnect_ptr && SQLFreeHandle_ptr && - SQLFreeStmt_ptr && SQLGetDiagRec_ptr && SQLParamData_ptr && + SQLFreeStmt_ptr && SQLGetDiagRec_ptr && SQLGetInfo_ptr && SQLParamData_ptr && SQLPutData_ptr && SQLTables_ptr && SQLDescribeParam_ptr && SQLGetTypeInfo_ptr && SQLProcedures_ptr && SQLForeignKeys_ptr && @@ -3422,7 +3424,8 @@ PYBIND11_MODULE(ddbc_bindings, m) { .def("rollback", &ConnectionHandle::rollback, "Rollback the current transaction") .def("set_autocommit", &ConnectionHandle::setAutocommit) .def("get_autocommit", &ConnectionHandle::getAutocommit) - .def("alloc_statement_handle", &ConnectionHandle::allocStatementHandle); + .def("alloc_statement_handle", &ConnectionHandle::allocStatementHandle) + .def("get_info", &ConnectionHandle::getInfo, py::arg("info_type")); m.def("enable_pooling", &enable_pooling, "Enable global connection pooling"); m.def("close_pooling", []() {ConnectionPoolManager::getInstance().closePools();}); m.def("DDBCSQLExecDirect", &SQLExecDirect_wrap, "Execute a SQL query directly"); diff --git a/mssql_python/pybind/ddbc_bindings.h b/mssql_python/pybind/ddbc_bindings.h index 63bee5439..176724a40 100644 --- a/mssql_python/pybind/ddbc_bindings.h +++ b/mssql_python/pybind/ddbc_bindings.h @@ -225,6 +225,7 @@ typedef SQLRETURN (SQL_API* SQLStatisticsFunc)(SQLHSTMT, SQLWCHAR*, SQLSMALLINT, typedef SQLRETURN (SQL_API* SQLColumnsFunc)(SQLHSTMT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, SQLSMALLINT); +typedef SQLRETURN (SQL_API* SQLGetInfoFunc)(SQLHDBC, SQLUSMALLINT, SQLPOINTER, SQLSMALLINT, SQLSMALLINT*); // Transaction APIs typedef SQLRETURN (SQL_API* SQLEndTranFunc)(SQLSMALLINT, SQLHANDLE, SQLSMALLINT); @@ -281,6 +282,7 @@ extern SQLPrimaryKeysFunc SQLPrimaryKeys_ptr; extern SQLSpecialColumnsFunc SQLSpecialColumns_ptr; extern SQLStatisticsFunc SQLStatistics_ptr; extern SQLColumnsFunc SQLColumns_ptr; +extern SQLGetInfoFunc SQLGetInfo_ptr; // Transaction APIs extern SQLEndTranFunc SQLEndTran_ptr; diff --git a/tests/test_001_globals.py b/tests/test_001_globals.py index 2f1e37540..85dc6e298 100644 --- a/tests/test_001_globals.py +++ b/tests/test_001_globals.py @@ -28,37 +28,6 @@ def test_paramstyle(): # Check if paramstyle has the expected value assert paramstyle == "qmark", "paramstyle should be 'qmark'" -def test_lowercase(): - # Check if lowercase has the expected default value - assert lowercase is False, "lowercase should default to False" - -def test_decimal_separator(): - """Test decimal separator functionality""" - - # Check default value - assert getDecimalSeparator() == '.', "Default decimal separator should be '.'" - - try: - # Test setting a new value - setDecimalSeparator(',') - assert getDecimalSeparator() == ',', "Decimal separator should be ',' after setting" - - # Test invalid input - with pytest.raises(ValueError): - setDecimalSeparator('too long') - - with pytest.raises(ValueError): - setDecimalSeparator('') - - with pytest.raises(ValueError): - setDecimalSeparator(123) # Non-string input - - finally: - # Restore default value - setDecimalSeparator('.') - assert getDecimalSeparator() == '.', "Decimal separator should be restored to '.'" - - def test_lowercase(): # Check if lowercase has the expected default value assert lowercase is False, "lowercase should default to False" diff --git a/tests/test_003_connection.py b/tests/test_003_connection.py index ac77be9dc..fe2625c7f 100644 --- a/tests/test_003_connection.py +++ b/tests/test_003_connection.py @@ -65,6 +65,7 @@ def clean_connection_state(db_connection): cleanup_cursor.close() except Exception: pass # Ignore errors during cleanup +from mssql_python.constants import GetInfoConstants as sql_const def drop_table_if_exists(cursor, table_name): """Drop the table if it exists""" @@ -4742,4 +4743,379 @@ def test_timeout_affects_all_cursors(db_connection): # with the current timeout setting finally: # Reset timeout - db_connection.timeout = original_timeout \ No newline at end of file + db_connection.timeout = original_timeout +def test_getinfo_basic_driver_info(db_connection): + """Test basic driver information info types.""" + + try: + # Driver name should be available + driver_name = db_connection.getinfo(sql_const.SQL_DRIVER_NAME.value) + print("Driver Name = ",driver_name) + assert driver_name is not None, "Driver name should not be None" + + # Driver version should be available + driver_ver = db_connection.getinfo(sql_const.SQL_DRIVER_VER.value) + print("Driver Version = ",driver_ver) + assert driver_ver is not None, "Driver version should not be None" + + # Data source name should be available + dsn = db_connection.getinfo(sql_const.SQL_DATA_SOURCE_NAME.value) + print("Data source name = ",dsn) + assert dsn is not None, "Data source name should not be None" + + # Server name should be available (might be empty in some configurations) + server_name = db_connection.getinfo(sql_const.SQL_SERVER_NAME.value) + print("Server Name = ",server_name) + assert server_name is not None, "Server name should not be None" + + # User name should be available (might be empty if using integrated auth) + user_name = db_connection.getinfo(sql_const.SQL_USER_NAME.value) + print("User Name = ",user_name) + assert user_name is not None, "User name should not be None" + + except Exception as e: + pytest.fail(f"getinfo failed for basic driver info: {e}") + +def test_getinfo_sql_support(db_connection): + """Test SQL support and conformance info types.""" + + try: + # SQL conformance level + sql_conformance = db_connection.getinfo(sql_const.SQL_SQL_CONFORMANCE.value) + print("SQL Conformance = ",sql_conformance) + assert sql_conformance is not None, "SQL conformance should not be None" + + # Keywords - may return a very long string + keywords = db_connection.getinfo(sql_const.SQL_KEYWORDS.value) + print("Keywords = ",keywords) + assert keywords is not None, "SQL keywords should not be None" + + # Identifier quote character + quote_char = db_connection.getinfo(sql_const.SQL_IDENTIFIER_QUOTE_CHAR.value) + print(f"Identifier quote char: '{quote_char}'") + assert quote_char is not None, "Identifier quote char should not be None" + + except Exception as e: + pytest.fail(f"getinfo failed for SQL support info: {e}") + +def test_getinfo_numeric_limits(db_connection): + """Test numeric limitation info types.""" + + try: + # Max column name length - should be a positive integer + max_col_name_len = db_connection.getinfo(sql_const.SQL_MAX_COLUMN_NAME_LEN.value) + assert isinstance(max_col_name_len, int), "Max column name length should be an integer" + assert max_col_name_len >= 0, "Max column name length should be non-negative" + + # Max table name length + max_table_name_len = db_connection.getinfo(sql_const.SQL_MAX_TABLE_NAME_LEN.value) + assert isinstance(max_table_name_len, int), "Max table name length should be an integer" + assert max_table_name_len >= 0, "Max table name length should be non-negative" + + # Max statement length - may return 0 for "unlimited" + max_statement_len = db_connection.getinfo(sql_const.SQL_MAX_STATEMENT_LEN.value) + assert isinstance(max_statement_len, int), "Max statement length should be an integer" + assert max_statement_len >= 0, "Max statement length should be non-negative" + + # Max connections - may return 0 for "unlimited" + max_connections = db_connection.getinfo(sql_const.SQL_MAX_DRIVER_CONNECTIONS.value) + assert isinstance(max_connections, int), "Max connections should be an integer" + assert max_connections >= 0, "Max connections should be non-negative" + + except Exception as e: + pytest.fail(f"getinfo failed for numeric limits info: {e}") + +def test_getinfo_catalog_support(db_connection): + """Test catalog support info types.""" + + try: + # Catalog support for tables + catalog_term = db_connection.getinfo(sql_const.SQL_CATALOG_TERM.value) + print("Catalog term = ",catalog_term) + assert catalog_term is not None, "Catalog term should not be None" + + # Catalog name separator + catalog_separator = db_connection.getinfo(sql_const.SQL_CATALOG_NAME_SEPARATOR.value) + print(f"Catalog name separator: '{catalog_separator}'") + assert catalog_separator is not None, "Catalog separator should not be None" + + # Schema term + schema_term = db_connection.getinfo(sql_const.SQL_SCHEMA_TERM.value) + print("Schema term = ",schema_term) + assert schema_term is not None, "Schema term should not be None" + + # Stored procedures support + procedures = db_connection.getinfo(sql_const.SQL_PROCEDURES.value) + print("Procedures = ",procedures) + assert procedures is not None, "Procedures support should not be None" + + except Exception as e: + pytest.fail(f"getinfo failed for catalog support info: {e}") + +def test_getinfo_transaction_support(db_connection): + """Test transaction support info types.""" + + try: + # Transaction support + txn_capable = db_connection.getinfo(sql_const.SQL_TXN_CAPABLE.value) + print("Transaction capable = ",txn_capable) + assert txn_capable is not None, "Transaction capability should not be None" + + # Default transaction isolation + default_txn_isolation = db_connection.getinfo(sql_const.SQL_DEFAULT_TXN_ISOLATION.value) + print("Default Transaction isolation = ",default_txn_isolation) + assert default_txn_isolation is not None, "Default transaction isolation should not be None" + + # Multiple active transactions support + multiple_txn = db_connection.getinfo(sql_const.SQL_MULTIPLE_ACTIVE_TXN.value) + print("Multiple transaction = ",multiple_txn) + assert multiple_txn is not None, "Multiple active transactions support should not be None" + + except Exception as e: + pytest.fail(f"getinfo failed for transaction support info: {e}") + +def test_getinfo_data_types(db_connection): + """Test data type support info types.""" + + try: + # Numeric functions + numeric_functions = db_connection.getinfo(sql_const.SQL_NUMERIC_FUNCTIONS.value) + assert isinstance(numeric_functions, int), "Numeric functions should be an integer" + + # String functions + string_functions = db_connection.getinfo(sql_const.SQL_STRING_FUNCTIONS.value) + assert isinstance(string_functions, int), "String functions should be an integer" + + # Date/time functions + datetime_functions = db_connection.getinfo(sql_const.SQL_DATETIME_FUNCTIONS.value) + assert isinstance(datetime_functions, int), "Datetime functions should be an integer" + + except Exception as e: + pytest.fail(f"getinfo failed for data type support info: {e}") + +def test_getinfo_invalid_constant(db_connection): + """Test getinfo behavior with invalid constants.""" + # Use a constant that doesn't exist in ODBC + non_existent_constant = 9999 + try: + result = db_connection.getinfo(non_existent_constant) + # If it doesn't raise an exception, it should return None or an empty value + assert result is None or result == 0 or result == "", "Invalid constant should return None/empty" + except Exception: + # It's also acceptable to raise an exception for invalid constants + pass + +def test_getinfo_type_consistency(db_connection): + """Test that getinfo returns consistent types for repeated calls.""" + + # Choose a few representative info types that don't depend on DBMS + info_types = [ + sql_const.SQL_DRIVER_NAME.value, + sql_const.SQL_MAX_COLUMN_NAME_LEN.value, + sql_const.SQL_TXN_CAPABLE.value, + sql_const.SQL_IDENTIFIER_QUOTE_CHAR.value + ] + + for info_type in info_types: + # Call getinfo twice with the same info type + result1 = db_connection.getinfo(info_type) + result2 = db_connection.getinfo(info_type) + + # Results should be consistent in type and value + assert type(result1) == type(result2), f"Type inconsistency for info type {info_type}" + assert result1 == result2, f"Value inconsistency for info type {info_type}" + +def test_getinfo_standard_types(db_connection): + """Test a representative set of standard ODBC info types.""" + + # Dictionary of common info types and their expected value types + # Avoid DBMS-specific info types + info_types = { + sql_const.SQL_ACCESSIBLE_TABLES.value: str, # "Y" or "N" + sql_const.SQL_DATA_SOURCE_NAME.value: str, # DSN + sql_const.SQL_TABLE_TERM.value: str, # Usually "table" + sql_const.SQL_PROCEDURES.value: str, # "Y" or "N" + sql_const.SQL_MAX_IDENTIFIER_LEN.value: int, # Max identifier length + sql_const.SQL_OUTER_JOINS.value: str, # "Y" or "N" + } + + for info_type, expected_type in info_types.items(): + try: + info_value = db_connection.getinfo(info_type) + + # Skip None values (unsupported by driver) + if info_value is None: + continue + + # Check type, allowing empty strings for string types + if expected_type == str: + assert isinstance(info_value, str), f"Info type {info_type} should return a string" + elif expected_type == int: + assert isinstance(info_value, int), f"Info type {info_type} should return an integer" + + except Exception as e: + # Log but don't fail - some drivers might not support all info types + print(f"Info type {info_type} failed: {e}") + +def test_connection_searchescape_basic(db_connection): + """Test the basic functionality of the searchescape property.""" + # Get the search escape character + escape_char = db_connection.searchescape + + # Verify it's not None + assert escape_char is not None, "Search escape character should not be None" + print(f"Search pattern escape character: '{escape_char}'") + + # Test property caching - calling it twice should return the same value + escape_char2 = db_connection.searchescape + assert escape_char == escape_char2, "Search escape character should be consistent" + +def test_connection_searchescape_with_percent(db_connection): + """Test using the searchescape property with percent wildcard.""" + escape_char = db_connection.searchescape + + # Skip test if we got a non-string or empty escape character + if not isinstance(escape_char, str) or not escape_char: + pytest.skip("No valid escape character available for testing") + + cursor = db_connection.cursor() + try: + # Create a temporary table with data containing % character + cursor.execute("CREATE TABLE #test_escape_percent (id INT, text VARCHAR(50))") + cursor.execute("INSERT INTO #test_escape_percent VALUES (1, 'abc%def')") + cursor.execute("INSERT INTO #test_escape_percent VALUES (2, 'abc_def')") + cursor.execute("INSERT INTO #test_escape_percent VALUES (3, 'abcdef')") + + # Use the escape character to find the exact % character + query = f"SELECT * FROM #test_escape_percent WHERE text LIKE 'abc{escape_char}%def' ESCAPE '{escape_char}'" + cursor.execute(query) + results = cursor.fetchall() + + # Should match only the row with the % character + assert len(results) == 1, f"Escaped LIKE query for % matched {len(results)} rows instead of 1" + if results: + assert 'abc%def' in results[0][1], "Escaped LIKE query did not match correct row" + + except Exception as e: + print(f"Note: LIKE escape test with % failed: {e}") + # Don't fail the test as some drivers might handle escaping differently + finally: + cursor.execute("DROP TABLE #test_escape_percent") + +def test_connection_searchescape_with_underscore(db_connection): + """Test using the searchescape property with underscore wildcard.""" + escape_char = db_connection.searchescape + + # Skip test if we got a non-string or empty escape character + if not isinstance(escape_char, str) or not escape_char: + pytest.skip("No valid escape character available for testing") + + cursor = db_connection.cursor() + try: + # Create a temporary table with data containing _ character + cursor.execute("CREATE TABLE #test_escape_underscore (id INT, text VARCHAR(50))") + cursor.execute("INSERT INTO #test_escape_underscore VALUES (1, 'abc_def')") + cursor.execute("INSERT INTO #test_escape_underscore VALUES (2, 'abcXdef')") # 'X' could match '_' + cursor.execute("INSERT INTO #test_escape_underscore VALUES (3, 'abcdef')") # No match + + # Use the escape character to find the exact _ character + query = f"SELECT * FROM #test_escape_underscore WHERE text LIKE 'abc{escape_char}_def' ESCAPE '{escape_char}'" + cursor.execute(query) + results = cursor.fetchall() + + # Should match only the row with the _ character + assert len(results) == 1, f"Escaped LIKE query for _ matched {len(results)} rows instead of 1" + if results: + assert 'abc_def' in results[0][1], "Escaped LIKE query did not match correct row" + + except Exception as e: + print(f"Note: LIKE escape test with _ failed: {e}") + # Don't fail the test as some drivers might handle escaping differently + finally: + cursor.execute("DROP TABLE #test_escape_underscore") + +def test_connection_searchescape_with_brackets(db_connection): + """Test using the searchescape property with bracket wildcards.""" + escape_char = db_connection.searchescape + + # Skip test if we got a non-string or empty escape character + if not isinstance(escape_char, str) or not escape_char: + pytest.skip("No valid escape character available for testing") + + cursor = db_connection.cursor() + try: + # Create a temporary table with data containing [ character + cursor.execute("CREATE TABLE #test_escape_brackets (id INT, text VARCHAR(50))") + cursor.execute("INSERT INTO #test_escape_brackets VALUES (1, 'abc[x]def')") + cursor.execute("INSERT INTO #test_escape_brackets VALUES (2, 'abcxdef')") + + # Use the escape character to find the exact [ character + # Note: This might not work on all drivers as bracket escaping varies + query = f"SELECT * FROM #test_escape_brackets WHERE text LIKE 'abc{escape_char}[x{escape_char}]def' ESCAPE '{escape_char}'" + cursor.execute(query) + results = cursor.fetchall() + + # Just check we got some kind of result without asserting specific behavior + print(f"Bracket escaping test returned {len(results)} rows") + + except Exception as e: + print(f"Note: LIKE escape test with brackets failed: {e}") + # Don't fail the test as bracket escaping varies significantly between drivers + finally: + cursor.execute("DROP TABLE #test_escape_brackets") + +def test_connection_searchescape_multiple_escapes(db_connection): + """Test using the searchescape property with multiple escape sequences.""" + escape_char = db_connection.searchescape + + # Skip test if we got a non-string or empty escape character + if not isinstance(escape_char, str) or not escape_char: + pytest.skip("No valid escape character available for testing") + + cursor = db_connection.cursor() + try: + # Create a temporary table with data containing multiple special chars + cursor.execute("CREATE TABLE #test_multiple_escapes (id INT, text VARCHAR(50))") + cursor.execute("INSERT INTO #test_multiple_escapes VALUES (1, 'abc%def_ghi')") + cursor.execute("INSERT INTO #test_multiple_escapes VALUES (2, 'abc%defXghi')") # Wouldn't match the pattern + cursor.execute("INSERT INTO #test_multiple_escapes VALUES (3, 'abcXdef_ghi')") # Wouldn't match the pattern + + # Use escape character for both % and _ + query = f""" + SELECT * FROM #test_multiple_escapes + WHERE text LIKE 'abc{escape_char}%def{escape_char}_ghi' ESCAPE '{escape_char}' + """ + cursor.execute(query) + results = cursor.fetchall() + + # Should match only the row with both % and _ + assert len(results) <= 1, f"Multiple escapes query matched {len(results)} rows instead of at most 1" + if len(results) == 1: + assert 'abc%def_ghi' in results[0][1], "Multiple escapes query matched incorrect row" + + except Exception as e: + print(f"Note: Multiple escapes test failed: {e}") + # Don't fail the test as escaping behavior varies + finally: + cursor.execute("DROP TABLE #test_multiple_escapes") + +def test_connection_searchescape_consistency(db_connection): + """Test that the searchescape property is cached and consistent.""" + # Call the property multiple times + escape1 = db_connection.searchescape + escape2 = db_connection.searchescape + escape3 = db_connection.searchescape + + # All calls should return the same value + assert escape1 == escape2 == escape3, "Searchescape property should be consistent" + + # Create a new connection and verify it returns the same escape character + # (assuming the same driver and connection settings) + if 'conn_str' in globals(): + try: + new_conn = connect(conn_str) + new_escape = new_conn.searchescape + assert new_escape == escape1, "Searchescape should be consistent across connections" + new_conn.close() + except Exception as e: + print(f"Note: New connection comparison failed: {e}") \ No newline at end of file diff --git a/tests/test_004_cursor.py b/tests/test_004_cursor.py index 0d2fc2323..9b7276ab6 100644 --- a/tests/test_004_cursor.py +++ b/tests/test_004_cursor.py @@ -9953,6 +9953,248 @@ def test_columns_cleanup(cursor, db_connection): except Exception as e: pytest.fail(f"Test cleanup failed: {e}") +def test_lowercase_attribute(cursor, db_connection): + """Test that the lowercase attribute properly converts column names to lowercase""" + + # Store original value to restore after test + original_lowercase = mssql_python.lowercase + drop_cursor = None + + try: + # Create a test table with mixed-case column names + cursor.execute(""" + CREATE TABLE #pytest_lowercase_test ( + ID INT PRIMARY KEY, + UserName VARCHAR(50), + EMAIL_ADDRESS VARCHAR(100), + PhoneNumber VARCHAR(20) + ) + """) + db_connection.commit() + + # Insert test data + cursor.execute(""" + INSERT INTO #pytest_lowercase_test (ID, UserName, EMAIL_ADDRESS, PhoneNumber) + VALUES (1, 'JohnDoe', 'john@example.com', '555-1234') + """) + db_connection.commit() + + # First test with lowercase=False (default) + mssql_python.lowercase = False + cursor1 = db_connection.cursor() + cursor1.execute("SELECT * FROM #pytest_lowercase_test") + + # Description column names should preserve original case + column_names1 = [desc[0] for desc in cursor1.description] + assert "ID" in column_names1, "Column 'ID' should be present with original case" + assert "UserName" in column_names1, "Column 'UserName' should be present with original case" + + # Make sure to consume all results and close the cursor + cursor1.fetchall() + cursor1.close() + + # Now test with lowercase=True + mssql_python.lowercase = True + cursor2 = db_connection.cursor() + cursor2.execute("SELECT * FROM #pytest_lowercase_test") + + # Description column names should be lowercase + column_names2 = [desc[0] for desc in cursor2.description] + assert "id" in column_names2, "Column names should be lowercase when lowercase=True" + assert "username" in column_names2, "Column names should be lowercase when lowercase=True" + + # Make sure to consume all results and close the cursor + cursor2.fetchall() + cursor2.close() + + # Create a fresh cursor for cleanup + drop_cursor = db_connection.cursor() + + finally: + # Restore original value + mssql_python.lowercase = original_lowercase + + try: + # Use a separate cursor for cleanup + if drop_cursor: + drop_cursor.execute("DROP TABLE IF EXISTS #pytest_lowercase_test") + db_connection.commit() + drop_cursor.close() + except Exception as e: + print(f"Warning: Failed to drop test table: {e}") + +def test_decimal_separator_function(cursor, db_connection): + """Test decimal separator functionality with database operations""" + # Store original value to restore after test + original_separator = mssql_python.getDecimalSeparator() + + try: + # Create test table + cursor.execute(""" + CREATE TABLE #pytest_decimal_separator_test ( + id INT PRIMARY KEY, + decimal_value DECIMAL(10, 2) + ) + """) + db_connection.commit() + + # Insert test values with default separator (.) + test_value = decimal.Decimal('123.45') + cursor.execute(""" + INSERT INTO #pytest_decimal_separator_test (id, decimal_value) + VALUES (1, ?) + """, [test_value]) + db_connection.commit() + + # First test with default decimal separator (.) + cursor.execute("SELECT id, decimal_value FROM #pytest_decimal_separator_test") + row = cursor.fetchone() + default_str = str(row) + assert '123.45' in default_str, "Default separator not found in string representation" + + # Now change to comma separator and test string representation + mssql_python.setDecimalSeparator(',') + cursor.execute("SELECT id, decimal_value FROM #pytest_decimal_separator_test") + row = cursor.fetchone() + + # This should format the decimal with a comma in the string representation + comma_str = str(row) + assert '123,45' in comma_str, f"Expected comma in string representation but got: {comma_str}" + + finally: + # Restore original decimal separator + mssql_python.setDecimalSeparator(original_separator) + + # Cleanup + cursor.execute("DROP TABLE IF EXISTS #pytest_decimal_separator_test") + db_connection.commit() + +def test_decimal_separator_basic_functionality(): + """Test basic decimal separator functionality without database operations""" + # Store original value to restore after test + original_separator = mssql_python.getDecimalSeparator() + + try: + # Test default value + assert mssql_python.getDecimalSeparator() == '.', "Default decimal separator should be '.'" + + # Test setting to comma + mssql_python.setDecimalSeparator(',') + assert mssql_python.getDecimalSeparator() == ',', "Decimal separator should be ',' after setting" + + # Test setting to other valid separators + mssql_python.setDecimalSeparator(':') + assert mssql_python.getDecimalSeparator() == ':', "Decimal separator should be ':' after setting" + + # Test invalid inputs + with pytest.raises(ValueError): + mssql_python.setDecimalSeparator('') # Empty string + + with pytest.raises(ValueError): + mssql_python.setDecimalSeparator('too_long') # More than one character + + with pytest.raises(ValueError): + mssql_python.setDecimalSeparator(123) # Not a string + + finally: + # Restore original separator + mssql_python.setDecimalSeparator(original_separator) + +def test_decimal_separator_with_multiple_values(cursor, db_connection): + """Test decimal separator with multiple different decimal values""" + original_separator = mssql_python.getDecimalSeparator() + + try: + # Create test table + cursor.execute(""" + CREATE TABLE #pytest_decimal_multi_test ( + id INT PRIMARY KEY, + positive_value DECIMAL(10, 2), + negative_value DECIMAL(10, 2), + zero_value DECIMAL(10, 2), + small_value DECIMAL(10, 4) + ) + """) + db_connection.commit() + + # Insert test data + cursor.execute(""" + INSERT INTO #pytest_decimal_multi_test VALUES (1, 123.45, -67.89, 0.00, 0.0001) + """) + db_connection.commit() + + # Test with default separator first + cursor.execute("SELECT * FROM #pytest_decimal_multi_test") + row = cursor.fetchone() + default_str = str(row) + assert '123.45' in default_str, "Default positive value formatting incorrect" + assert '-67.89' in default_str, "Default negative value formatting incorrect" + + # Change to comma separator + mssql_python.setDecimalSeparator(',') + cursor.execute("SELECT * FROM #pytest_decimal_multi_test") + row = cursor.fetchone() + comma_str = str(row) + + # Verify comma is used in all decimal values + assert '123,45' in comma_str, "Positive value not formatted with comma" + assert '-67,89' in comma_str, "Negative value not formatted with comma" + assert '0,00' in comma_str, "Zero value not formatted with comma" + assert '0,0001' in comma_str, "Small value not formatted with comma" + + finally: + # Restore original separator + mssql_python.setDecimalSeparator(original_separator) + + # Cleanup + cursor.execute("DROP TABLE IF EXISTS #pytest_decimal_multi_test") + db_connection.commit() + +def test_decimal_separator_calculations(cursor, db_connection): + """Test that decimal separator doesn't affect calculations""" + original_separator = mssql_python.getDecimalSeparator() + + try: + # Create test table + cursor.execute(""" + CREATE TABLE #pytest_decimal_calc_test ( + id INT PRIMARY KEY, + value1 DECIMAL(10, 2), + value2 DECIMAL(10, 2) + ) + """) + db_connection.commit() + + # Insert test data + cursor.execute(""" + INSERT INTO #pytest_decimal_calc_test VALUES (1, 10.25, 5.75) + """) + db_connection.commit() + + # Test with default separator + cursor.execute("SELECT value1 + value2 AS sum_result FROM #pytest_decimal_calc_test") + row = cursor.fetchone() + assert row.sum_result == decimal.Decimal('16.00'), "Sum calculation incorrect with default separator" + + # Change to comma separator + mssql_python.setDecimalSeparator(',') + + # Calculations should still work correctly + cursor.execute("SELECT value1 + value2 AS sum_result FROM #pytest_decimal_calc_test") + row = cursor.fetchone() + assert row.sum_result == decimal.Decimal('16.00'), "Sum calculation affected by separator change" + + # But string representation should use comma + assert '16,00' in str(row), "Sum result not formatted with comma in string representation" + + finally: + # Restore original separator + mssql_python.setDecimalSeparator(original_separator) + + # Cleanup + cursor.execute("DROP TABLE IF EXISTS #pytest_decimal_calc_test") + db_connection.commit() + def test_close(db_connection): """Test closing the cursor""" try: