From 8f691619b99a360db15a179a2e021d775568999f Mon Sep 17 00:00:00 2001 From: Jahnvi Thakkar Date: Wed, 20 Aug 2025 14:42:52 +0530 Subject: [PATCH 1/7] FEAT: Adding lowercase for global variable --- mssql_python/__init__.py | 26 ++++++++++---- mssql_python/cursor.py | 66 +++++++++++++++++++++-------------- mssql_python/row.py | 31 +++++++++++++---- tests/test_001_globals.py | 8 ++++- tests/test_004_cursor.py | 72 ++++++++++++++++++++++++++++++++++++++- 5 files changed, 163 insertions(+), 40 deletions(-) diff --git a/mssql_python/__init__.py b/mssql_python/__init__.py index 6bf957779..8f8635964 100644 --- a/mssql_python/__init__.py +++ b/mssql_python/__init__.py @@ -6,6 +6,26 @@ # Exceptions # https://www.python.org/dev/peps/pep-0249/#exceptions + +# GLOBALS +# Read-Only +apilevel = "2.0" +paramstyle = "qmark" +threadsafety = 1 + +class Settings: + def __init__(self): + self.lowercase = False + +# Create a global instance +_settings = Settings() + +def get_settings(): + return _settings + +lowercase = _settings.lowercase # Default is False + +# Import necessary modules from .exceptions import ( Warning, Error, @@ -47,12 +67,6 @@ # Constants from .constants import ConstantsDDBC -# GLOBALS -# Read-Only -apilevel = "2.0" -paramstyle = "qmark" -threadsafety = 1 - from .pooling import PoolingManager def pooling(max_size=100, idle_timeout=600, enabled=True): # """ diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index ed1bb70dc..912cb4a8c 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -17,7 +17,8 @@ from mssql_python.helpers import check_error, log from mssql_python import ddbc_bindings from mssql_python.exceptions import InterfaceError -from .row import Row +from mssql_python.row import Row +from mssql_python import get_settings class Cursor: @@ -73,6 +74,8 @@ def __init__(self, connection) -> None: # Is a list instead of a bool coz bools in Python are immutable. # Hence, we can't pass around bools by reference & modify them. # Therefore, it must be a list with exactly one bool element. + + self.lowercase = get_settings().lowercase def _is_unicode_string(self, param): """ @@ -480,26 +483,32 @@ def _create_parameter_types_list(self, parameter, param_info, parameters_list, i paraminfo.decimalDigits = decimal_digits return paraminfo - def _initialize_description(self): - """ - Initialize the description attribute using SQLDescribeCol. - """ - col_metadata = [] - ret = ddbc_bindings.DDBCSQLDescribeCol(self.hstmt, col_metadata) - check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret) - - self.description = [ - ( - col["ColumnName"], - self._map_data_type(col["DataType"]), - None, - col["ColumnSize"], - col["ColumnSize"], - col["DecimalDigits"], - col["Nullable"] == ddbc_sql_const.SQL_NULLABLE.value, - ) - for col in col_metadata - ] + def _initialize_description(self, column_metadata=None): + """Initialize the description attribute from column metadata.""" + if not column_metadata: + self.description = None + return + import mssql_python + + description = [] + for i, col in enumerate(column_metadata): + # Get column name - lowercase it if the lowercase flag is set + column_name = col["ColumnName"] + + if mssql_python.lowercase: + column_name = column_name.lower() + + # Add to description tuple (7 elements as per PEP-249) + description.append(( + column_name, # name + self._map_data_type(col["DataType"]), # type_code + None, # display_size + col["ColumnSize"], # internal_size + col["ColumnSize"], # precision - should match ColumnSize + col["DecimalDigits"], # scale + col["Nullable"] == ddbc_sql_const.SQL_NULLABLE.value, # null_ok + )) + self.description = description def _map_data_type(self, sql_type): """ @@ -611,7 +620,14 @@ def execute( self.rowcount = ddbc_bindings.DDBCSQLRowCount(self.hstmt) # Initialize description after execution - self._initialize_description() + # After successful execution, initialize description if there are results + column_metadata = [] + try: + ddbc_bindings.DDBCSQLDescribeCol(self.hstmt, column_metadata) + self._initialize_description(column_metadata) + except Exception as e: + # If describe fails, it's likely there are no results (e.g., for INSERT) + self.description = None @staticmethod def _select_best_sample_value(column): @@ -727,7 +743,7 @@ def fetchone(self) -> Union[None, Row]: return None # Create and return a Row object - return Row(row_data, self.description) + return Row(self, self.description, row_data) def fetchmany(self, size: int = None) -> List[Row]: """ @@ -752,7 +768,7 @@ def fetchmany(self, size: int = None) -> List[Row]: ret = ddbc_bindings.DDBCSQLFetchMany(self.hstmt, rows_data, size) # Convert raw data to Row objects - return [Row(row_data, self.description) for row_data in rows_data] + return [Row(self, self.description, row_data) for row_data in rows_data] def fetchall(self) -> List[Row]: """ @@ -768,7 +784,7 @@ def fetchall(self) -> List[Row]: ret = ddbc_bindings.DDBCSQLFetchAll(self.hstmt, rows_data) # Convert raw data to Row objects - return [Row(row_data, self.description) for row_data in rows_data] + return [Row(self, self.description, row_data) for row_data in rows_data] def nextset(self) -> Union[bool, None]: """ diff --git a/mssql_python/row.py b/mssql_python/row.py index 2c88412de..0b1fd33ea 100644 --- a/mssql_python/row.py +++ b/mssql_python/row.py @@ -9,14 +9,17 @@ class Row: print(row.column_name) # Access by column name """ - def __init__(self, values, cursor_description): + def __init__(self, cursor, description, values, column_map=None): """ Initialize a Row object with values and cursor description. Args: + cursor: The cursor object + description: The cursor description containing column metadata values: List of values for this row - cursor_description: The cursor description containing column metadata + column_map: Optional pre-built column map (for optimization) """ + self._cursor = cursor self._values = values # TODO: ADO task - Optimize memory usage by sharing column map across rows @@ -26,10 +29,14 @@ def __init__(self, values, cursor_description): # 3. Remove cursor_description from Row objects entirely # Create mapping of column names to indices - self._column_map = {} - for i, desc in enumerate(cursor_description): - if desc and desc[0]: # Ensure column name exists - self._column_map[desc[0]] = i + # If column_map is not provided, build it from description + if column_map is None: + column_map = {} + for i, col_desc in enumerate(description): + col_name = col_desc[0] # Name is first item in description tuple + column_map[col_name] = i + + self._column_map = column_map def __getitem__(self, index): """Allow accessing by numeric index: row[0]""" @@ -37,9 +44,19 @@ def __getitem__(self, index): def __getattr__(self, name): """Allow accessing by column name as attribute: row.column_name""" + # Handle lowercase attribute access - if lowercase is enabled, + # try to match attribute names case-insensitively if name in self._column_map: return self._values[self._column_map[name]] - raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") + + # If lowercase is enabled on the cursor, try case-insensitive lookup + if hasattr(self._cursor, 'lowercase') and self._cursor.lowercase: + name_lower = name.lower() + for col_name in self._column_map: + if col_name.lower() == name_lower: + return self._values[self._column_map[col_name]] + + raise AttributeError(f"Row has no attribute '{name}'") def __eq__(self, other): """ diff --git a/tests/test_001_globals.py b/tests/test_001_globals.py index f41a9a14f..fbee7ec5e 100644 --- a/tests/test_001_globals.py +++ b/tests/test_001_globals.py @@ -4,12 +4,13 @@ - test_apilevel: Check if apilevel has the expected value. - test_threadsafety: Check if threadsafety has the expected value. - test_paramstyle: Check if paramstyle has the expected value. +- test_lowercase: Check if lowercase has the expected value. """ import pytest # Import global variables from the repository -from mssql_python import apilevel, threadsafety, paramstyle +from mssql_python import apilevel, threadsafety, paramstyle, lowercase def test_apilevel(): # Check if apilevel has the expected value @@ -22,3 +23,8 @@ def test_threadsafety(): def test_paramstyle(): # Check if paramstyle has the expected value assert paramstyle == "qmark", "paramstyle should be 'qmark'" + +def test_lowercase(): + # Check if lowercase has the expected default value + assert lowercase is False, "lowercase should default to False" + diff --git a/tests/test_004_cursor.py b/tests/test_004_cursor.py index 6a8c84281..728b27e23 100644 --- a/tests/test_004_cursor.py +++ b/tests/test_004_cursor.py @@ -12,6 +12,7 @@ from datetime import datetime, date, time import decimal from mssql_python import Connection +import mssql_python # Setup test table TEST_TABLE = """ @@ -1313,6 +1314,76 @@ def test_row_column_mapping(cursor, db_connection): cursor.execute("DROP TABLE #pytest_row_test") db_connection.commit() +def test_lowercase_attribute(cursor, db_connection): + """Test that the lowercase attribute properly converts column names to lowercase""" + + # Store original value to restore after test + original_lowercase = mssql_python.lowercase + drop_cursor = None + + try: + # Create a test table with mixed-case column names + cursor.execute(""" + CREATE TABLE #pytest_lowercase_test ( + ID INT PRIMARY KEY, + UserName VARCHAR(50), + EMAIL_ADDRESS VARCHAR(100), + PhoneNumber VARCHAR(20) + ) + """) + db_connection.commit() + + # Insert test data + cursor.execute(""" + INSERT INTO #pytest_lowercase_test (ID, UserName, EMAIL_ADDRESS, PhoneNumber) + VALUES (1, 'JohnDoe', 'john@example.com', '555-1234') + """) + db_connection.commit() + + # First test with lowercase=False (default) + mssql_python.lowercase = False + cursor1 = db_connection.cursor() + cursor1.execute("SELECT * FROM #pytest_lowercase_test") + + # Description column names should preserve original case + column_names1 = [desc[0] for desc in cursor1.description] + assert "ID" in column_names1, "Column 'ID' should be present with original case" + assert "UserName" in column_names1, "Column 'UserName' should be present with original case" + + # Make sure to consume all results and close the cursor + cursor1.fetchall() + cursor1.close() + + # Now test with lowercase=True + mssql_python.lowercase = True + cursor2 = db_connection.cursor() + cursor2.execute("SELECT * FROM #pytest_lowercase_test") + + # Description column names should be lowercase + column_names2 = [desc[0] for desc in cursor2.description] + assert "id" in column_names2, "Column names should be lowercase when lowercase=True" + assert "username" in column_names2, "Column names should be lowercase when lowercase=True" + + # Make sure to consume all results and close the cursor + cursor2.fetchall() + cursor2.close() + + # Create a fresh cursor for cleanup + drop_cursor = db_connection.cursor() + + finally: + # Restore original value + mssql_python.lowercase = original_lowercase + + try: + # Use a separate cursor for cleanup + if drop_cursor: + drop_cursor.execute("DROP TABLE IF EXISTS #pytest_lowercase_test") + db_connection.commit() + drop_cursor.close() + except Exception as e: + print(f"Warning: Failed to drop test table: {e}") + def test_close(db_connection): """Test closing the cursor""" try: @@ -1323,4 +1394,3 @@ def test_close(db_connection): pytest.fail(f"Cursor close test failed: {e}") finally: cursor = db_connection.cursor() - \ No newline at end of file From ee871ae315f400243e8870948e216551801ac49a Mon Sep 17 00:00:00 2001 From: Jahnvi Thakkar Date: Thu, 21 Aug 2025 12:16:12 +0530 Subject: [PATCH 2/7] FEAT: Adding getDecimalSeperator and setDecimalSeperator as global functions --- mssql_python/__init__.py | 38 +++++- mssql_python/pybind/ddbc_bindings.cpp | 44 +++++-- mssql_python/pybind/ddbc_bindings.h | 6 + mssql_python/row.py | 20 ++- tests/test_001_globals.py | 28 ++++- tests/test_004_cursor.py | 172 ++++++++++++++++++++++++++ 6 files changed, 294 insertions(+), 14 deletions(-) diff --git a/mssql_python/__init__.py b/mssql_python/__init__.py index 8f8635964..ec0f3b40a 100644 --- a/mssql_python/__init__.py +++ b/mssql_python/__init__.py @@ -16,8 +16,9 @@ class Settings: def __init__(self): self.lowercase = False + self.decimal_separator = "." -# Create a global instance +# Global settings instance _settings = Settings() def get_settings(): @@ -25,6 +26,40 @@ def get_settings(): lowercase = _settings.lowercase # Default is False +# Set the initial decimal separator in C++ +from .ddbc_bindings import DDBCSetDecimalSeparator +DDBCSetDecimalSeparator(_settings.decimal_separator) + +# New functions for decimal separator control +def setDecimalSeparator(separator): + """ + Sets the decimal separator character used when parsing NUMERIC/DECIMAL values + from the database, e.g. the "." in "1,234.56". + + The default is "." (period). This function overrides the default. + + Args: + separator (str): The character to use as decimal separator + """ + if not isinstance(separator, str) or len(separator) != 1: + raise ValueError("Decimal separator must be a single character string") + + _settings.decimal_separator = separator + + # Update the C++ side + from .ddbc_bindings import DDBCSetDecimalSeparator + DDBCSetDecimalSeparator(separator) + +def getDecimalSeparator(): + """ + Returns the decimal separator character used when parsing NUMERIC/DECIMAL values + from the database. + + Returns: + str: The current decimal separator character + """ + return _settings.decimal_separator + # Import necessary modules from .exceptions import ( Warning, @@ -85,4 +120,3 @@ def pooling(max_size=100, idle_timeout=600, enabled=True): PoolingManager.disable() else: PoolingManager.enable(max_size, idle_timeout) - \ No newline at end of file diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 1b37b8f0f..b5588a25d 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -1600,12 +1600,17 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p ret = SQLGetData_ptr(hStmt, i, SQL_C_CHAR, numericStr, sizeof(numericStr), &indicator); if (SQL_SUCCEEDED(ret)) { - try{ - // Convert numericStr to py::decimal.Decimal and append to row - row.append(py::module_::import("decimal").attr("Decimal")( - std::string(reinterpret_cast(numericStr), indicator))); + try { + // Use the original string with period for Python's Decimal constructor + std::string numStr(reinterpret_cast(numericStr), indicator); + + // Create Python Decimal object + py::object decimalObj = py::module_::import("decimal").attr("Decimal")(numStr); + + // Add to row + row.append(decimalObj); } catch (const py::error_already_set& e) { - // If the conversion fails, append None + // If conversion fails, append None LOG("Error converting to decimal: {}", e.what()); row.append(py::none()); } @@ -2085,11 +2090,20 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum case SQL_DECIMAL: case SQL_NUMERIC: { try { - // Convert numericStr to py::decimal.Decimal and append to row - row.append(py::module_::import("decimal").attr("Decimal")(std::string( - reinterpret_cast( - &buffers.charBuffers[col - 1][i * MAX_DIGITS_IN_NUMERIC]), - buffers.indicators[col - 1][i]))); + // Convert the string to use the current decimal separator + std::string numStr(reinterpret_cast( + &buffers.charBuffers[col - 1][i * MAX_DIGITS_IN_NUMERIC]), + buffers.indicators[col - 1][i]); + if (g_decimalSeparator != ".") { + // Replace the driver's decimal point with our configured separator + size_t pos = numStr.find('.'); + if (pos != std::string::npos) { + numStr.replace(pos, 1, g_decimalSeparator); + } + } + + // Convert to Python decimal + row.append(py::module_::import("decimal").attr("Decimal")(numStr)); } catch (const py::error_already_set& e) { // Handle the exception, e.g., log the error and append py::none() LOG("Error converting to decimal: {}", e.what()); @@ -2480,6 +2494,14 @@ void enable_pooling(int maxSize, int idleTimeout) { }); } +// Global decimal separator setting with default value +std::string g_decimalSeparator = "."; + +void DDBCSetDecimalSeparator(const std::string& separator) { + LOG("Setting decimal separator to: {}", separator); + g_decimalSeparator = separator; +} + // Architecture-specific defines #ifndef ARCHITECTURE #define ARCHITECTURE "win64" // Default to win64 if not defined during compilation @@ -2553,6 +2575,8 @@ PYBIND11_MODULE(ddbc_bindings, m) { m.def("DDBCSQLFetchAll", &FetchAll_wrap, "Fetch all rows from the result set"); m.def("DDBCSQLFreeHandle", &SQLFreeHandle_wrap, "Free a handle"); m.def("DDBCSQLCheckError", &SQLCheckError_Wrap, "Check for driver errors"); + m.def("DDBCSetDecimalSeparator", &DDBCSetDecimalSeparator, "Set the decimal separator character"); + // Add a version attribute m.attr("__version__") = "1.0.0"; diff --git a/mssql_python/pybind/ddbc_bindings.h b/mssql_python/pybind/ddbc_bindings.h index 22bc524bd..d142276c6 100644 --- a/mssql_python/pybind/ddbc_bindings.h +++ b/mssql_python/pybind/ddbc_bindings.h @@ -271,3 +271,9 @@ inline std::wstring Utf8ToWString(const std::string& str) { return converter.from_bytes(str); #endif } + +// Global decimal separator setting +extern std::string g_decimalSeparator; + +// Function to set the decimal separator +void DDBCSetDecimalSeparator(const std::string& separator); diff --git a/mssql_python/row.py b/mssql_python/row.py index 0b1fd33ea..1f54e8c8c 100644 --- a/mssql_python/row.py +++ b/mssql_python/row.py @@ -79,7 +79,25 @@ def __iter__(self): def __str__(self): """Return string representation of the row""" - return str(tuple(self._values)) + from decimal import Decimal + from mssql_python import getDecimalSeparator + + parts = [] + for value in self: + if isinstance(value, Decimal): + # Apply custom decimal separator for display + sep = getDecimalSeparator() + if sep != '.' and value is not None: + s = str(value) + if '.' in s: + s = s.replace('.', sep) + parts.append(s) + else: + parts.append(str(value)) + else: + parts.append(repr(value)) + + return "(" + ", ".join(parts) + ")" def __repr__(self): """Return a detailed string representation for debugging""" diff --git a/tests/test_001_globals.py b/tests/test_001_globals.py index fbee7ec5e..779d46a81 100644 --- a/tests/test_001_globals.py +++ b/tests/test_001_globals.py @@ -10,7 +10,7 @@ import pytest # Import global variables from the repository -from mssql_python import apilevel, threadsafety, paramstyle, lowercase +from mssql_python import apilevel, threadsafety, paramstyle, lowercase, getDecimalSeparator, setDecimalSeparator def test_apilevel(): # Check if apilevel has the expected value @@ -28,3 +28,29 @@ def test_lowercase(): # Check if lowercase has the expected default value assert lowercase is False, "lowercase should default to False" +def test_decimal_separator(): + """Test decimal separator functionality""" + + # Check default value + assert getDecimalSeparator() == '.', "Default decimal separator should be '.'" + + try: + # Test setting a new value + setDecimalSeparator(',') + assert getDecimalSeparator() == ',', "Decimal separator should be ',' after setting" + + # Test invalid input + with pytest.raises(ValueError): + setDecimalSeparator('too long') + + with pytest.raises(ValueError): + setDecimalSeparator('') + + with pytest.raises(ValueError): + setDecimalSeparator(123) # Non-string input + + finally: + # Restore default value + setDecimalSeparator('.') + assert getDecimalSeparator() == '.', "Decimal separator should be restored to '.'" + diff --git a/tests/test_004_cursor.py b/tests/test_004_cursor.py index 728b27e23..9a63e27f7 100644 --- a/tests/test_004_cursor.py +++ b/tests/test_004_cursor.py @@ -1384,6 +1384,178 @@ def test_lowercase_attribute(cursor, db_connection): except Exception as e: print(f"Warning: Failed to drop test table: {e}") +def test_decimal_separator_function(cursor, db_connection): + """Test decimal separator functionality with database operations""" + # Store original value to restore after test + original_separator = mssql_python.getDecimalSeparator() + + try: + # Create test table + cursor.execute(""" + CREATE TABLE #pytest_decimal_separator_test ( + id INT PRIMARY KEY, + decimal_value DECIMAL(10, 2) + ) + """) + db_connection.commit() + + # Insert test values with default separator (.) + test_value = decimal.Decimal('123.45') + cursor.execute(""" + INSERT INTO #pytest_decimal_separator_test (id, decimal_value) + VALUES (1, ?) + """, [test_value]) + db_connection.commit() + + # First test with default decimal separator (.) + cursor.execute("SELECT id, decimal_value FROM #pytest_decimal_separator_test") + row = cursor.fetchone() + default_str = str(row) + assert '123.45' in default_str, "Default separator not found in string representation" + + # Now change to comma separator and test string representation + mssql_python.setDecimalSeparator(',') + cursor.execute("SELECT id, decimal_value FROM #pytest_decimal_separator_test") + row = cursor.fetchone() + + # This should format the decimal with a comma in the string representation + comma_str = str(row) + assert '123,45' in comma_str, f"Expected comma in string representation but got: {comma_str}" + + finally: + # Restore original decimal separator + mssql_python.setDecimalSeparator(original_separator) + + # Cleanup + cursor.execute("DROP TABLE IF EXISTS #pytest_decimal_separator_test") + db_connection.commit() + +def test_decimal_separator_basic_functionality(): + """Test basic decimal separator functionality without database operations""" + # Store original value to restore after test + original_separator = mssql_python.getDecimalSeparator() + + try: + # Test default value + assert mssql_python.getDecimalSeparator() == '.', "Default decimal separator should be '.'" + + # Test setting to comma + mssql_python.setDecimalSeparator(',') + assert mssql_python.getDecimalSeparator() == ',', "Decimal separator should be ',' after setting" + + # Test setting to other valid separators + mssql_python.setDecimalSeparator(':') + assert mssql_python.getDecimalSeparator() == ':', "Decimal separator should be ':' after setting" + + # Test invalid inputs + with pytest.raises(ValueError): + mssql_python.setDecimalSeparator('') # Empty string + + with pytest.raises(ValueError): + mssql_python.setDecimalSeparator('too_long') # More than one character + + with pytest.raises(ValueError): + mssql_python.setDecimalSeparator(123) # Not a string + + finally: + # Restore original separator + mssql_python.setDecimalSeparator(original_separator) + +def test_decimal_separator_with_multiple_values(cursor, db_connection): + """Test decimal separator with multiple different decimal values""" + original_separator = mssql_python.getDecimalSeparator() + + try: + # Create test table + cursor.execute(""" + CREATE TABLE #pytest_decimal_multi_test ( + id INT PRIMARY KEY, + positive_value DECIMAL(10, 2), + negative_value DECIMAL(10, 2), + zero_value DECIMAL(10, 2), + small_value DECIMAL(10, 4) + ) + """) + db_connection.commit() + + # Insert test data + cursor.execute(""" + INSERT INTO #pytest_decimal_multi_test VALUES (1, 123.45, -67.89, 0.00, 0.0001) + """) + db_connection.commit() + + # Test with default separator first + cursor.execute("SELECT * FROM #pytest_decimal_multi_test") + row = cursor.fetchone() + default_str = str(row) + assert '123.45' in default_str, "Default positive value formatting incorrect" + assert '-67.89' in default_str, "Default negative value formatting incorrect" + + # Change to comma separator + mssql_python.setDecimalSeparator(',') + cursor.execute("SELECT * FROM #pytest_decimal_multi_test") + row = cursor.fetchone() + comma_str = str(row) + + # Verify comma is used in all decimal values + assert '123,45' in comma_str, "Positive value not formatted with comma" + assert '-67,89' in comma_str, "Negative value not formatted with comma" + assert '0,00' in comma_str, "Zero value not formatted with comma" + assert '0,0001' in comma_str, "Small value not formatted with comma" + + finally: + # Restore original separator + mssql_python.setDecimalSeparator(original_separator) + + # Cleanup + cursor.execute("DROP TABLE IF EXISTS #pytest_decimal_multi_test") + db_connection.commit() + +def test_decimal_separator_calculations(cursor, db_connection): + """Test that decimal separator doesn't affect calculations""" + original_separator = mssql_python.getDecimalSeparator() + + try: + # Create test table + cursor.execute(""" + CREATE TABLE #pytest_decimal_calc_test ( + id INT PRIMARY KEY, + value1 DECIMAL(10, 2), + value2 DECIMAL(10, 2) + ) + """) + db_connection.commit() + + # Insert test data + cursor.execute(""" + INSERT INTO #pytest_decimal_calc_test VALUES (1, 10.25, 5.75) + """) + db_connection.commit() + + # Test with default separator + cursor.execute("SELECT value1 + value2 AS sum_result FROM #pytest_decimal_calc_test") + row = cursor.fetchone() + assert row.sum_result == decimal.Decimal('16.00'), "Sum calculation incorrect with default separator" + + # Change to comma separator + mssql_python.setDecimalSeparator(',') + + # Calculations should still work correctly + cursor.execute("SELECT value1 + value2 AS sum_result FROM #pytest_decimal_calc_test") + row = cursor.fetchone() + assert row.sum_result == decimal.Decimal('16.00'), "Sum calculation affected by separator change" + + # But string representation should use comma + assert '16,00' in str(row), "Sum result not formatted with comma in string representation" + + finally: + # Restore original separator + mssql_python.setDecimalSeparator(original_separator) + + # Cleanup + cursor.execute("DROP TABLE IF EXISTS #pytest_decimal_calc_test") + db_connection.commit() + def test_close(db_connection): """Test closing the cursor""" try: From e6c47d8daf6ca2d5ba9bb7741dbfa55221e07fe2 Mon Sep 17 00:00:00 2001 From: Jahnvi Thakkar Date: Thu, 21 Aug 2025 13:48:21 +0530 Subject: [PATCH 3/7] FEAT: Adding getinfo --- mssql_python/connection.py | 24 ++++ mssql_python/pybind/connection/connection.cpp | 107 ++++++++++++++++++ mssql_python/pybind/connection/connection.h | 6 + mssql_python/pybind/ddbc_bindings.cpp | 7 +- mssql_python/pybind/ddbc_bindings.h | 2 + tests/test_003_connection.py | 84 ++++++++++++++ 6 files changed, 228 insertions(+), 2 deletions(-) diff --git a/mssql_python/connection.py b/mssql_python/connection.py index 12760df41..2a7af1464 100644 --- a/mssql_python/connection.py +++ b/mssql_python/connection.py @@ -185,6 +185,30 @@ def cursor(self) -> Cursor: cursor = Cursor(self) self._cursors.add(cursor) # Track the cursor return cursor + + def getinfo(self, info_type): + """ + Return general information about the driver and data source. + + Args: + info_type (int): The type of information to return. See the ODBC + SQLGetInfo documentation for the supported values. + + Returns: + The requested information. The type of the returned value depends + on the information requested. It will be a string, integer, or boolean. + + Raises: + DatabaseError: If there is an error retrieving the information. + InterfaceError: If the connection is closed. + """ + if self._closed: + raise InterfaceError( + driver_error="Cannot get info on closed connection", + ddbc_error="Cannot get info on closed connection", + ) + + return self._conn.get_info(info_type) def commit(self) -> None: """ diff --git a/mssql_python/pybind/connection/connection.cpp b/mssql_python/pybind/connection/connection.cpp index 9782efd22..9c1263f16 100644 --- a/mssql_python/pybind/connection/connection.cpp +++ b/mssql_python/pybind/connection/connection.cpp @@ -314,4 +314,111 @@ SqlHandlePtr ConnectionHandle::allocStatementHandle() { ThrowStdException("Connection object is not initialized"); } return _conn->allocStatementHandle(); +} + +py::object Connection::getInfo(SQLUSMALLINT infoType) const { + if (!_dbcHandle) { + ThrowStdException("Connection handle not allocated"); + } + + LOG("Getting connection info for type {}", infoType); + + // For string results - allocate a buffer + char charBuffer[1024] = {0}; + SQLSMALLINT stringLength = 0; + SQLRETURN ret; + + // First try to get the info as a string or binary data + ret = SQLGetInfo_ptr(_dbcHandle->get(), infoType, charBuffer, sizeof(charBuffer), &stringLength); + if (!SQL_SUCCEEDED(ret)) { + checkError(ret); + } + + // Determine return type based on the InfoType + // String types usually have InfoType > 10000 + if (infoType > 10000 || + infoType == SQL_DATA_SOURCE_NAME || + infoType == SQL_DBMS_NAME || + infoType == SQL_DBMS_VER || + infoType == SQL_DRIVER_NAME || + infoType == SQL_DRIVER_VER) { + // Return as string + return py::str(charBuffer); + } + else if (infoType == SQL_DRIVER_ODBC_VER || + infoType == SQL_SERVER_NAME) { + // Return as string + return py::str(charBuffer); + } + else { + // For numeric types, we need to interpret the buffer based on the expected return type + // Handle common numeric types + switch (infoType) { + // 16-bit unsigned integers + case SQL_MAX_CONCURRENT_ACTIVITIES: + case SQL_MAX_DRIVER_CONNECTIONS: + case SQL_ODBC_API_CONFORMANCE: + case SQL_ODBC_SQL_CONFORMANCE: + { + SQLUSMALLINT value = *reinterpret_cast(charBuffer); + return py::int_(value); + } + + // 32-bit unsigned integers + case SQL_ASYNC_MODE: + case SQL_GETDATA_EXTENSIONS: + case SQL_MAX_ASYNC_CONCURRENT_STATEMENTS: + case SQL_MAX_COLUMNS_IN_GROUP_BY: + case SQL_MAX_COLUMNS_IN_ORDER_BY: + case SQL_MAX_COLUMNS_IN_SELECT: + case SQL_MAX_COLUMNS_IN_TABLE: + case SQL_MAX_ROW_SIZE: + case SQL_MAX_TABLES_IN_SELECT: + case SQL_MAX_USER_NAME_LEN: + case SQL_NUMERIC_FUNCTIONS: + case SQL_STRING_FUNCTIONS: + case SQL_SYSTEM_FUNCTIONS: + case SQL_TIMEDATE_FUNCTIONS: + { + SQLUINTEGER value = *reinterpret_cast(charBuffer); + return py::int_(value); + } + + // Boolean flags (32-bit mask) + case SQL_AGGREGATE_FUNCTIONS: + case SQL_ALTER_TABLE: + case SQL_CATALOG_USAGE: + case SQL_DATETIME_LITERALS: + case SQL_INDEX_KEYWORDS: + case SQL_INSERT_STATEMENT: + case SQL_SCHEMA_USAGE: + case SQL_SQL_CONFORMANCE: + case SQL_SQL92_DATETIME_FUNCTIONS: + case SQL_SQL92_NUMERIC_VALUE_FUNCTIONS: + case SQL_SQL92_PREDICATES: + case SQL_SQL92_RELATIONAL_JOIN_OPERATORS: + case SQL_SQL92_STRING_FUNCTIONS: + case SQL_STATIC_CURSOR_ATTRIBUTES1: + case SQL_STATIC_CURSOR_ATTRIBUTES2: + { + SQLUINTEGER value = *reinterpret_cast(charBuffer); + return py::int_(value); + } + + // Handle any other types as integers + default: + SQLUINTEGER value = *reinterpret_cast(charBuffer); + return py::int_(value); + } + } + + // Default return in case nothing matched + return py::none(); +} + +py::object ConnectionHandle::getInfo(SQLUSMALLINT infoType) const { + if (!_conn) { + ThrowStdException("Connection object is not initialized"); + } + return _conn->getInfo(infoType); } \ No newline at end of file diff --git a/mssql_python/pybind/connection/connection.h b/mssql_python/pybind/connection/connection.h index 6129125e1..66dd58952 100644 --- a/mssql_python/pybind/connection/connection.h +++ b/mssql_python/pybind/connection/connection.h @@ -42,6 +42,9 @@ class Connection { // Allocate a new statement handle on this connection. SqlHandlePtr allocStatementHandle(); + // Get information about the driver and data source + py::object getInfo(SQLUSMALLINT infoType) const; + private: void allocateDbcHandle(); void checkError(SQLRETURN ret) const; @@ -67,6 +70,9 @@ class ConnectionHandle { bool getAutocommit() const; SqlHandlePtr allocStatementHandle(); + // Get information about the driver and data source + py::object getInfo(SQLUSMALLINT infoType) const; + private: std::shared_ptr _conn; bool _usePool; diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index b5588a25d..142b3c2f2 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -123,6 +123,7 @@ SQLBindColFunc SQLBindCol_ptr = nullptr; SQLDescribeColFunc SQLDescribeCol_ptr = nullptr; SQLMoreResultsFunc SQLMoreResults_ptr = nullptr; SQLColAttributeFunc SQLColAttribute_ptr = nullptr; +SQLGetInfoFunc SQLGetInfo_ptr = nullptr; // Transaction APIs SQLEndTranFunc SQLEndTran_ptr = nullptr; @@ -779,6 +780,7 @@ DriverHandle LoadDriverOrThrowException() { SQLDescribeCol_ptr = GetFunctionPointer(handle, "SQLDescribeColW"); SQLMoreResults_ptr = GetFunctionPointer(handle, "SQLMoreResults"); SQLColAttribute_ptr = GetFunctionPointer(handle, "SQLColAttributeW"); + SQLGetInfo_ptr = GetFunctionPointer(handle, "SQLGetInfoW"); SQLEndTran_ptr = GetFunctionPointer(handle, "SQLEndTran"); SQLDisconnect_ptr = GetFunctionPointer(handle, "SQLDisconnect"); @@ -796,7 +798,7 @@ DriverHandle LoadDriverOrThrowException() { SQLGetData_ptr && SQLNumResultCols_ptr && SQLBindCol_ptr && SQLDescribeCol_ptr && SQLMoreResults_ptr && SQLColAttribute_ptr && SQLEndTran_ptr && SQLDisconnect_ptr && SQLFreeHandle_ptr && - SQLFreeStmt_ptr && SQLGetDiagRec_ptr; + SQLFreeStmt_ptr && SQLGetDiagRec_ptr && SQLGetInfo_ptr; if (!success) { ThrowStdException("Failed to load required function pointers from driver."); @@ -2554,7 +2556,8 @@ PYBIND11_MODULE(ddbc_bindings, m) { .def("rollback", &ConnectionHandle::rollback, "Rollback the current transaction") .def("set_autocommit", &ConnectionHandle::setAutocommit) .def("get_autocommit", &ConnectionHandle::getAutocommit) - .def("alloc_statement_handle", &ConnectionHandle::allocStatementHandle); + .def("alloc_statement_handle", &ConnectionHandle::allocStatementHandle) + .def("get_info", &ConnectionHandle::getInfo, py::arg("info_type")); m.def("enable_pooling", &enable_pooling, "Enable global connection pooling"); m.def("close_pooling", []() {ConnectionPoolManager::getInstance().closePools();}); m.def("DDBCSQLExecDirect", &SQLExecDirect_wrap, "Execute a SQL query directly"); diff --git a/mssql_python/pybind/ddbc_bindings.h b/mssql_python/pybind/ddbc_bindings.h index d142276c6..21ba959c8 100644 --- a/mssql_python/pybind/ddbc_bindings.h +++ b/mssql_python/pybind/ddbc_bindings.h @@ -105,6 +105,7 @@ typedef SQLRETURN (SQL_API* SQLDescribeColFunc)(SQLHSTMT, SQLUSMALLINT, SQLWCHAR typedef SQLRETURN (SQL_API* SQLMoreResultsFunc)(SQLHSTMT); typedef SQLRETURN (SQL_API* SQLColAttributeFunc)(SQLHSTMT, SQLUSMALLINT, SQLUSMALLINT, SQLPOINTER, SQLSMALLINT, SQLSMALLINT*, SQLPOINTER); +typedef SQLRETURN (SQL_API* SQLGetInfoFunc)(SQLHDBC, SQLUSMALLINT, SQLPOINTER, SQLSMALLINT, SQLSMALLINT*); // Transaction APIs typedef SQLRETURN (SQL_API* SQLEndTranFunc)(SQLSMALLINT, SQLHANDLE, SQLSMALLINT); @@ -148,6 +149,7 @@ extern SQLBindColFunc SQLBindCol_ptr; extern SQLDescribeColFunc SQLDescribeCol_ptr; extern SQLMoreResultsFunc SQLMoreResults_ptr; extern SQLColAttributeFunc SQLColAttribute_ptr; +extern SQLGetInfoFunc SQLGetInfo_ptr; // Transaction APIs extern SQLEndTranFunc SQLEndTran_ptr; diff --git a/tests/test_003_connection.py b/tests/test_003_connection.py index 51fce818e..79bd1d2fc 100644 --- a/tests/test_003_connection.py +++ b/tests/test_003_connection.py @@ -485,3 +485,87 @@ def test_connection_pooling_basic(conn_str): conn1.close() conn2.close() + +def test_getinfo_basic(db_connection): + """Test that getinfo() can retrieve basic driver and data source information.""" + # Test SQL_DATA_SOURCE_NAME + data_source = db_connection.getinfo(1) # SQL_DATA_SOURCE_NAME + assert data_source is not None, "Failed to retrieve data source name" + + # # Test SQL_DBMS_NAME + # dbms_name = db_connection.getinfo(17) # SQL_DBMS_NAME + # assert dbms_name is not None, "Failed to retrieve DBMS name" + # assert "SQL Server" in dbms_name, "DBMS name should contain 'SQL Server'" + + # Test SQL_DRIVER_NAME + driver_name = db_connection.getinfo(6) # SQL_DRIVER_NAME + assert driver_name is not None, "Failed to retrieve driver name" + assert "ODBC" in driver_name, "Driver name should contain 'ODBC'" + +def test_getinfo_return_types(db_connection): + """Test that getinfo() returns appropriate data types for different info types.""" + # String type + dbms_ver = db_connection.getinfo(18) # SQL_DBMS_VER + assert isinstance(dbms_ver, str), "DBMS version should be a string" + + # Integer type + max_columns = db_connection.getinfo(30) # SQL_MAX_COLUMNS_IN_TABLE + assert isinstance(max_columns, int), "MAX_COLUMNS_IN_TABLE should be an integer" + + # Another integer type + max_tables = db_connection.getinfo(106) # SQL_MAX_TABLES_IN_SELECT + assert isinstance(max_tables, int), "MAX_TABLES_IN_SELECT should be an integer" + +def test_getinfo_closed_connection(conn_str): + """Test that getinfo() raises an exception when called on a closed connection.""" + from mssql_python import connect + + # Create and close a connection + conn = connect(conn_str) + conn.close() + + # Calling getinfo() on a closed connection should raise an exception + with pytest.raises(InterfaceError) as excinfo: + conn.getinfo(1) # SQL_DATA_SOURCE_NAME + + assert "closed connection" in str(excinfo.value).lower(), "Exception message should mention closed connection" + +def test_getinfo_invalid_type(db_connection): + """Test that getinfo() handles invalid info types gracefully.""" + # Using a very large number that's unlikely to be a valid info type + with pytest.raises(Exception): + db_connection.getinfo(999999) + +def test_getinfo_driver_version(db_connection): + """Test that getinfo() can retrieve the driver version.""" + driver_ver = db_connection.getinfo(7) # SQL_DRIVER_VER + assert driver_ver is not None, "Failed to retrieve driver version" + print(driver_ver) + + # Driver version should have a pattern like "nn.nn.nnnn.nn" + import re + assert re.match(r"\d+\.\d+(\.\d+)*", driver_ver), f"Driver version '{driver_ver}' not in expected format" + +def test_getinfo_odbc_version(db_connection): + """Test that getinfo() can retrieve the ODBC version.""" + odbc_ver = db_connection.getinfo(10) # SQL_DRIVER_ODBC_VER + assert odbc_ver is not None, "Failed to retrieve ODBC version" + print(odbc_ver) + + # ODBC version should have a pattern like "nn.nn" or "nn.nn.nnnn" + import re + assert re.match(r"\d+\.\d+(\.\d+)*", odbc_ver), f"ODBC version '{odbc_ver}' not in expected format" + +def test_getinfo_numeric_constants(db_connection): + """Test that getinfo() properly returns numeric constants.""" + # SQL_MAX_CONCURRENT_ACTIVITIES should be a reasonable value > 0 + max_activities = db_connection.getinfo(1) # SQL_MAX_CONCURRENT_ACTIVITIES + assert isinstance(max_activities, (int, str)), "MAX_CONCURRENT_ACTIVITIES should be numeric or string" + + if isinstance(max_activities, int): + assert max_activities >= 1, "MAX_CONCURRENT_ACTIVITIES should be >= 1" + + # SQL_TXN_CAPABLE should indicate transaction capability + txn_capable = db_connection.getinfo(46) # SQL_TXN_CAPABLE + assert isinstance(txn_capable, int), "TXN_CAPABLE should be an integer" + assert txn_capable > 0, "Driver should support transactions" \ No newline at end of file From f6b328bbf084ea02d44b6dbc0d7badd6b79f43ef Mon Sep 17 00:00:00 2001 From: Jahnvi Thakkar Date: Fri, 22 Aug 2025 17:45:52 +0530 Subject: [PATCH 4/7] Changing testcases --- mssql_python/constants.py | 138 ++++++++++++++++++ tests/test_003_connection.py | 272 ++++++++++++++++++++++++++--------- 2 files changed, 339 insertions(+), 71 deletions(-) diff --git a/mssql_python/constants.py b/mssql_python/constants.py index 81e60d37e..e75b74443 100644 --- a/mssql_python/constants.py +++ b/mssql_python/constants.py @@ -117,6 +117,144 @@ class ConstantsDDBC(Enum): SQL_NULLABLE = 1 SQL_MAX_NUMERIC_LEN = 16 +class GetInfoConstants(Enum): + """ + These constants are used with various methods like getinfo(). + """ + + # Driver and database information + SQL_DRIVER_NAME = 6 + SQL_DRIVER_VER = 7 + SQL_DRIVER_ODBC_VER = 77 + SQL_DRIVER_HLIB = 76 + SQL_DRIVER_HENV = 75 + SQL_DRIVER_HDBC = 74 + SQL_DATA_SOURCE_NAME = 2 + SQL_DATABASE_NAME = 16 + SQL_SERVER_NAME = 13 + SQL_USER_NAME = 47 + + # SQL conformance and support + SQL_SQL_CONFORMANCE = 118 + SQL_KEYWORDS = 89 + SQL_IDENTIFIER_CASE = 28 + SQL_IDENTIFIER_QUOTE_CHAR = 29 + SQL_SPECIAL_CHARACTERS = 94 + SQL_SQL92_ENTRY_SQL = 127 + SQL_SQL92_INTERMEDIATE_SQL = 128 + SQL_SQL92_FULL_SQL = 129 + SQL_SUBQUERIES = 95 + SQL_EXPRESSIONS_IN_ORDERBY = 27 + SQL_CORRELATION_NAME = 74 + SQL_SEARCH_PATTERN_ESCAPE = 14 + + # Catalog and schema support + SQL_CATALOG_TERM = 42 + SQL_CATALOG_NAME_SEPARATOR = 41 + SQL_SCHEMA_TERM = 39 + SQL_TABLE_TERM = 45 + SQL_PROCEDURES = 21 + SQL_ACCESSIBLE_TABLES = 19 + SQL_ACCESSIBLE_PROCEDURES = 20 + SQL_CATALOG_NAME = 10002 + SQL_CATALOG_USAGE = 92 + SQL_SCHEMA_USAGE = 91 + SQL_COLUMN_ALIAS = 87 + SQL_DESCRIBE_PARAMETER = 10002 + + # Transaction support + SQL_TXN_CAPABLE = 46 + SQL_TXN_ISOLATION_OPTION = 72 + SQL_DEFAULT_TXN_ISOLATION = 26 + SQL_MULTIPLE_ACTIVE_TXN = 37 + SQL_TXN_ISOLATION_LEVEL = 108 + + # Data type support + SQL_NUMERIC_FUNCTIONS = 49 + SQL_STRING_FUNCTIONS = 50 + SQL_DATETIME_FUNCTIONS = 51 + SQL_SYSTEM_FUNCTIONS = 58 + SQL_CONVERT_FUNCTIONS = 48 + SQL_LIKE_ESCAPE_CLAUSE = 113 + + # Numeric limits + SQL_MAX_COLUMN_NAME_LEN = 30 + SQL_MAX_TABLE_NAME_LEN = 35 + SQL_MAX_SCHEMA_NAME_LEN = 32 + SQL_MAX_CATALOG_NAME_LEN = 34 + SQL_MAX_IDENTIFIER_LEN = 10005 + SQL_MAX_STATEMENT_LEN = 105 + SQL_MAX_CHAR_LITERAL_LEN = 108 + SQL_MAX_BINARY_LITERAL_LEN = 112 + SQL_MAX_COLUMNS_IN_TABLE = 101 + SQL_MAX_COLUMNS_IN_SELECT = 100 + SQL_MAX_COLUMNS_IN_GROUP_BY = 97 + SQL_MAX_COLUMNS_IN_ORDER_BY = 99 + SQL_MAX_COLUMNS_IN_INDEX = 98 + SQL_MAX_TABLES_IN_SELECT = 106 + SQL_MAX_CONCURRENT_ACTIVITIES = 1 + SQL_MAX_DRIVER_CONNECTIONS = 0 + SQL_MAX_ROW_SIZE = 104 + SQL_MAX_USER_NAME_LEN = 107 + + # Connection attributes + SQL_ACTIVE_CONNECTIONS = 0 + SQL_ACTIVE_STATEMENTS = 1 + SQL_DATA_SOURCE_READ_ONLY = 25 + SQL_NEED_LONG_DATA_LEN = 111 + SQL_GETDATA_EXTENSIONS = 81 + + # Result set and cursor attributes + SQL_CURSOR_COMMIT_BEHAVIOR = 23 + SQL_CURSOR_ROLLBACK_BEHAVIOR = 24 + SQL_CURSOR_SENSITIVITY = 10001 + SQL_BOOKMARK_PERSISTENCE = 82 + SQL_DYNAMIC_CURSOR_ATTRIBUTES1 = 144 + SQL_DYNAMIC_CURSOR_ATTRIBUTES2 = 145 + SQL_FORWARD_ONLY_CURSOR_ATTRIBUTES1 = 146 + SQL_FORWARD_ONLY_CURSOR_ATTRIBUTES2 = 147 + SQL_STATIC_CURSOR_ATTRIBUTES1 = 150 + SQL_STATIC_CURSOR_ATTRIBUTES2 = 151 + SQL_KEYSET_CURSOR_ATTRIBUTES1 = 148 + SQL_KEYSET_CURSOR_ATTRIBUTES2 = 149 + SQL_SCROLL_OPTIONS = 44 + SQL_SCROLL_CONCURRENCY = 43 + SQL_FETCH_DIRECTION = 8 + SQL_ROWSET_SIZE = 9 + SQL_CONCURRENCY = 7 + SQL_ROW_NUMBER = 14 + SQL_STATIC_SENSITIVITY = 83 + SQL_BATCH_SUPPORT = 121 + SQL_BATCH_ROW_COUNT = 120 + SQL_PARAM_ARRAY_ROW_COUNTS = 153 + SQL_PARAM_ARRAY_SELECTS = 154 + + # Positioned statement support + SQL_POSITIONED_STATEMENTS = 80 + + # Other constants + SQL_GROUP_BY = 88 + SQL_OJ_CAPABILITIES = 65 + SQL_ORDER_BY_COLUMNS_IN_SELECT = 90 + SQL_OUTER_JOINS = 38 + SQL_QUOTED_IDENTIFIER_CASE = 93 + SQL_CONCAT_NULL_BEHAVIOR = 22 + SQL_NULL_COLLATION = 85 + SQL_ALTER_TABLE = 86 + SQL_UNION = 96 + SQL_DDL_INDEX = 170 + SQL_MULT_RESULT_SETS = 36 + SQL_OWNER_USAGE = 91 + SQL_QUALIFIER_USAGE = 92 + SQL_TIMEDATE_ADD_INTERVALS = 109 + SQL_TIMEDATE_DIFF_INTERVALS = 110 + + # Return values for some getinfo functions + SQL_IC_UPPER = 1 + SQL_IC_LOWER = 2 + SQL_IC_SENSITIVE = 3 + SQL_IC_MIXED = 4 + class AuthType(Enum): """Constants for authentication types""" INTERACTIVE = "activedirectoryinteractive" diff --git a/tests/test_003_connection.py b/tests/test_003_connection.py index 79bd1d2fc..8ddb33ef0 100644 --- a/tests/test_003_connection.py +++ b/tests/test_003_connection.py @@ -23,6 +23,7 @@ import time from mssql_python import Connection, connect, pooling import threading +from mssql_python.constants import GetInfoConstants as sql_const def drop_table_if_exists(cursor, table_name): """Drop the table if it exists""" @@ -486,86 +487,215 @@ def test_connection_pooling_basic(conn_str): conn1.close() conn2.close() -def test_getinfo_basic(db_connection): - """Test that getinfo() can retrieve basic driver and data source information.""" - # Test SQL_DATA_SOURCE_NAME - data_source = db_connection.getinfo(1) # SQL_DATA_SOURCE_NAME - assert data_source is not None, "Failed to retrieve data source name" +def test_getinfo_basic_driver_info(db_connection): + """Test basic driver information info types.""" - # # Test SQL_DBMS_NAME - # dbms_name = db_connection.getinfo(17) # SQL_DBMS_NAME - # assert dbms_name is not None, "Failed to retrieve DBMS name" - # assert "SQL Server" in dbms_name, "DBMS name should contain 'SQL Server'" - - # Test SQL_DRIVER_NAME - driver_name = db_connection.getinfo(6) # SQL_DRIVER_NAME - assert driver_name is not None, "Failed to retrieve driver name" - assert "ODBC" in driver_name, "Driver name should contain 'ODBC'" - -def test_getinfo_return_types(db_connection): - """Test that getinfo() returns appropriate data types for different info types.""" - # String type - dbms_ver = db_connection.getinfo(18) # SQL_DBMS_VER - assert isinstance(dbms_ver, str), "DBMS version should be a string" + try: + # Driver name should be available + driver_name = db_connection.getinfo(sql_const.SQL_DRIVER_NAME.value) + print("Driver Name = ",driver_name) + assert driver_name is not None, "Driver name should not be None" + + # Driver version should be available + driver_ver = db_connection.getinfo(sql_const.SQL_DRIVER_VER.value) + print("Driver Version = ",driver_ver) + assert driver_ver is not None, "Driver version should not be None" + + # Data source name should be available + dsn = db_connection.getinfo(sql_const.SQL_DATA_SOURCE_NAME.value) + print("Data source name = ",dsn) + assert dsn is not None, "Data source name should not be None" + + # Server name should be available (might be empty in some configurations) + server_name = db_connection.getinfo(sql_const.SQL_SERVER_NAME.value) + print("Server Name = ",server_name) + assert server_name is not None, "Server name should not be None" + + # User name should be available (might be empty if using integrated auth) + user_name = db_connection.getinfo(sql_const.SQL_USER_NAME.value) + print("User Name = ",user_name) + assert user_name is not None, "User name should not be None" + + except Exception as e: + pytest.fail(f"getinfo failed for basic driver info: {e}") + +def test_getinfo_sql_support(db_connection): + """Test SQL support and conformance info types.""" - # Integer type - max_columns = db_connection.getinfo(30) # SQL_MAX_COLUMNS_IN_TABLE - assert isinstance(max_columns, int), "MAX_COLUMNS_IN_TABLE should be an integer" + try: + # SQL conformance level + sql_conformance = db_connection.getinfo(sql_const.SQL_SQL_CONFORMANCE.value) + print("SQL Conformance = ",sql_conformance) + assert sql_conformance is not None, "SQL conformance should not be None" + + # Keywords - may return a very long string + keywords = db_connection.getinfo(sql_const.SQL_KEYWORDS.value) + print("Keywords = ",keywords) + assert keywords is not None, "SQL keywords should not be None" + + # Identifier quote character + quote_char = db_connection.getinfo(sql_const.SQL_IDENTIFIER_QUOTE_CHAR.value) + print(f"Identifier quote char: '{quote_char}'") + assert quote_char is not None, "Identifier quote char should not be None" + + except Exception as e: + pytest.fail(f"getinfo failed for SQL support info: {e}") + +def test_getinfo_numeric_limits(db_connection): + """Test numeric limitation info types.""" - # Another integer type - max_tables = db_connection.getinfo(106) # SQL_MAX_TABLES_IN_SELECT - assert isinstance(max_tables, int), "MAX_TABLES_IN_SELECT should be an integer" + try: + # Max column name length - should be a positive integer + max_col_name_len = db_connection.getinfo(sql_const.SQL_MAX_COLUMN_NAME_LEN.value) + assert isinstance(max_col_name_len, int), "Max column name length should be an integer" + assert max_col_name_len >= 0, "Max column name length should be non-negative" + + # Max table name length + max_table_name_len = db_connection.getinfo(sql_const.SQL_MAX_TABLE_NAME_LEN.value) + assert isinstance(max_table_name_len, int), "Max table name length should be an integer" + assert max_table_name_len >= 0, "Max table name length should be non-negative" + + # Max statement length - may return 0 for "unlimited" + max_statement_len = db_connection.getinfo(sql_const.SQL_MAX_STATEMENT_LEN.value) + assert isinstance(max_statement_len, int), "Max statement length should be an integer" + assert max_statement_len >= 0, "Max statement length should be non-negative" + + # Max connections - may return 0 for "unlimited" + max_connections = db_connection.getinfo(sql_const.SQL_MAX_DRIVER_CONNECTIONS.value) + assert isinstance(max_connections, int), "Max connections should be an integer" + assert max_connections >= 0, "Max connections should be non-negative" + + except Exception as e: + pytest.fail(f"getinfo failed for numeric limits info: {e}") -def test_getinfo_closed_connection(conn_str): - """Test that getinfo() raises an exception when called on a closed connection.""" - from mssql_python import connect +def test_getinfo_catalog_support(db_connection): + """Test catalog support info types.""" - # Create and close a connection - conn = connect(conn_str) - conn.close() + try: + # Catalog support for tables + catalog_term = db_connection.getinfo(sql_const.SQL_CATALOG_TERM.value) + print("Catalof term = ",catalog_term) + assert catalog_term is not None, "Catalog term should not be None" + + # Catalog name separator + catalog_separator = db_connection.getinfo(sql_const.SQL_CATALOG_NAME_SEPARATOR.value) + print(f"Catalog name separator: '{catalog_separator}'") + assert catalog_separator is not None, "Catalog separator should not be None" + + # Schema term + schema_term = db_connection.getinfo(sql_const.SQL_SCHEMA_TERM.value) + print("Schema term = ",schema_term) + assert schema_term is not None, "Schema term should not be None" + + # Stored procedures support + procedures = db_connection.getinfo(sql_const.SQL_PROCEDURES.value) + print("Procedures = ",procedures) + assert procedures is not None, "Procedures support should not be None" + + except Exception as e: + pytest.fail(f"getinfo failed for catalog support info: {e}") + +def test_getinfo_transaction_support(db_connection): + """Test transaction support info types.""" - # Calling getinfo() on a closed connection should raise an exception - with pytest.raises(InterfaceError) as excinfo: - conn.getinfo(1) # SQL_DATA_SOURCE_NAME + try: + # Transaction support + txn_capable = db_connection.getinfo(sql_const.SQL_TXN_CAPABLE.value) + print("Transaction capable = ",txn_capable) + assert txn_capable is not None, "Transaction capability should not be None" + + # Default transaction isolation + default_txn_isolation = db_connection.getinfo(sql_const.SQL_DEFAULT_TXN_ISOLATION.value) + print("Default Transaction isolation = ",default_txn_isolation) + assert default_txn_isolation is not None, "Default transaction isolation should not be None" + + # Multiple active transactions support + multiple_txn = db_connection.getinfo(sql_const.SQL_MULTIPLE_ACTIVE_TXN.value) + print("Multiple transaction = ",multiple_txn) + assert multiple_txn is not None, "Multiple active transactions support should not be None" + + except Exception as e: + pytest.fail(f"getinfo failed for transaction support info: {e}") + +def test_getinfo_data_types(db_connection): + """Test data type support info types.""" - assert "closed connection" in str(excinfo.value).lower(), "Exception message should mention closed connection" + try: + # Numeric functions + numeric_functions = db_connection.getinfo(sql_const.SQL_NUMERIC_FUNCTIONS.value) + assert isinstance(numeric_functions, int), "Numeric functions should be an integer" + + # String functions + string_functions = db_connection.getinfo(sql_const.SQL_STRING_FUNCTIONS.value) + assert isinstance(string_functions, int), "String functions should be an integer" + + # Date/time functions + datetime_functions = db_connection.getinfo(sql_const.SQL_DATETIME_FUNCTIONS.value) + assert isinstance(datetime_functions, int), "Datetime functions should be an integer" + + except Exception as e: + pytest.fail(f"getinfo failed for data type support info: {e}") -def test_getinfo_invalid_type(db_connection): - """Test that getinfo() handles invalid info types gracefully.""" - # Using a very large number that's unlikely to be a valid info type - with pytest.raises(Exception): - db_connection.getinfo(999999) +def test_getinfo_invalid_constant(db_connection): + """Test getinfo behavior with invalid constants.""" + # Use a constant that doesn't exist in ODBC + non_existent_constant = 9999 + try: + result = db_connection.getinfo(non_existent_constant) + # If it doesn't raise an exception, it should return None or an empty value + assert result is None or result == 0 or result == "", "Invalid constant should return None/empty" + except Exception: + # It's also acceptable to raise an exception for invalid constants + pass -def test_getinfo_driver_version(db_connection): - """Test that getinfo() can retrieve the driver version.""" - driver_ver = db_connection.getinfo(7) # SQL_DRIVER_VER - assert driver_ver is not None, "Failed to retrieve driver version" - print(driver_ver) - - # Driver version should have a pattern like "nn.nn.nnnn.nn" - import re - assert re.match(r"\d+\.\d+(\.\d+)*", driver_ver), f"Driver version '{driver_ver}' not in expected format" - -def test_getinfo_odbc_version(db_connection): - """Test that getinfo() can retrieve the ODBC version.""" - odbc_ver = db_connection.getinfo(10) # SQL_DRIVER_ODBC_VER - assert odbc_ver is not None, "Failed to retrieve ODBC version" - print(odbc_ver) +def test_getinfo_type_consistency(db_connection): + """Test that getinfo returns consistent types for repeated calls.""" + + # Choose a few representative info types that don't depend on DBMS + info_types = [ + sql_const.SQL_DRIVER_NAME.value, + sql_const.SQL_MAX_COLUMN_NAME_LEN.value, + sql_const.SQL_TXN_CAPABLE.value, + sql_const.SQL_IDENTIFIER_QUOTE_CHAR.value + ] - # ODBC version should have a pattern like "nn.nn" or "nn.nn.nnnn" - import re - assert re.match(r"\d+\.\d+(\.\d+)*", odbc_ver), f"ODBC version '{odbc_ver}' not in expected format" - -def test_getinfo_numeric_constants(db_connection): - """Test that getinfo() properly returns numeric constants.""" - # SQL_MAX_CONCURRENT_ACTIVITIES should be a reasonable value > 0 - max_activities = db_connection.getinfo(1) # SQL_MAX_CONCURRENT_ACTIVITIES - assert isinstance(max_activities, (int, str)), "MAX_CONCURRENT_ACTIVITIES should be numeric or string" + for info_type in info_types: + # Call getinfo twice with the same info type + result1 = db_connection.getinfo(info_type) + result2 = db_connection.getinfo(info_type) + + # Results should be consistent in type and value + assert type(result1) == type(result2), f"Type inconsistency for info type {info_type}" + assert result1 == result2, f"Value inconsistency for info type {info_type}" + +def test_getinfo_standard_types(db_connection): + """Test a representative set of standard ODBC info types.""" - if isinstance(max_activities, int): - assert max_activities >= 1, "MAX_CONCURRENT_ACTIVITIES should be >= 1" + # Dictionary of common info types and their expected value types + # Avoid DBMS-specific info types + info_types = { + sql_const.SQL_ACCESSIBLE_TABLES.value: str, # "Y" or "N" + sql_const.SQL_DATA_SOURCE_NAME.value: str, # DSN + sql_const.SQL_TABLE_TERM.value: str, # Usually "table" + sql_const.SQL_PROCEDURES.value: str, # "Y" or "N" + sql_const.SQL_MAX_IDENTIFIER_LEN.value: int, # Max identifier length + sql_const.SQL_OUTER_JOINS.value: str, # "Y" or "N" + } - # SQL_TXN_CAPABLE should indicate transaction capability - txn_capable = db_connection.getinfo(46) # SQL_TXN_CAPABLE - assert isinstance(txn_capable, int), "TXN_CAPABLE should be an integer" - assert txn_capable > 0, "Driver should support transactions" \ No newline at end of file + for info_type, expected_type in info_types.items(): + try: + info_value = db_connection.getinfo(info_type) + + # Skip None values (unsupported by driver) + if info_value is None: + continue + + # Check type, allowing empty strings for string types + if expected_type == str: + assert isinstance(info_value, str), f"Info type {info_type} should return a string" + elif expected_type == int: + assert isinstance(info_value, int), f"Info type {info_type} should return an integer" + + except Exception as e: + # Log but don't fail - some drivers might not support all info types + print(f"Info type {info_type} failed: {e}") \ No newline at end of file From 95c7d94272e363e686366b76b875810cf12df1d8 Mon Sep 17 00:00:00 2001 From: Jahnvi Thakkar Date: Tue, 16 Sep 2025 13:54:23 +0530 Subject: [PATCH 5/7] Resolving comments --- mssql_python/pybind/connection/connection.cpp | 116 ++++++++++++++---- 1 file changed, 95 insertions(+), 21 deletions(-) diff --git a/mssql_python/pybind/connection/connection.cpp b/mssql_python/pybind/connection/connection.cpp index 9c1263f16..04ad28778 100644 --- a/mssql_python/pybind/connection/connection.cpp +++ b/mssql_python/pybind/connection/connection.cpp @@ -10,6 +10,7 @@ #include #define SQL_COPT_SS_ACCESS_TOKEN 1256 // Custom attribute ID for access token +#define SQL_MAX_SMALL_INT 32767 // Maximum value for SQLSMALLINT static SqlHandlePtr getEnvHandle() { static SqlHandlePtr envHandle = []() -> SqlHandlePtr { @@ -323,45 +324,86 @@ py::object Connection::getInfo(SQLUSMALLINT infoType) const { LOG("Getting connection info for type {}", infoType); - // For string results - allocate a buffer - char charBuffer[1024] = {0}; - SQLSMALLINT stringLength = 0; + // Use a vector for dynamic sizing + std::vector buffer(1024, 0); + SQLSMALLINT actualLength = 0; SQLRETURN ret; - // First try to get the info as a string or binary data - ret = SQLGetInfo_ptr(_dbcHandle->get(), infoType, charBuffer, sizeof(charBuffer), &stringLength); + // First try to get the info - handle SQLSMALLINT size limit + SQLSMALLINT bufferSize = (buffer.size() <= SQL_MAX_SMALL_INT) + ? static_cast(buffer.size()) + : SQL_MAX_SMALL_INT; + + ret = SQLGetInfo_ptr(_dbcHandle->get(), infoType, buffer.data(), bufferSize, &actualLength); + + // If truncation occurred (actualLength >= bufferSize means truncation) + if (SQL_SUCCEEDED(ret) && actualLength >= bufferSize) { + // Resize buffer to the needed size (add 1 for null terminator) + buffer.resize(actualLength + 1, 0); + + // Call again with the larger buffer - handle SQLSMALLINT size limit again + bufferSize = (buffer.size() <= SQL_MAX_SMALL_INT) + ? static_cast(buffer.size()) + : SQL_MAX_SMALL_INT; + + ret = SQLGetInfo_ptr(_dbcHandle->get(), infoType, buffer.data(), bufferSize, &actualLength); + } + + // Check for errors if (!SQL_SUCCEEDED(ret)) { checkError(ret); } + // Note: This implementation assumes the ODBC driver handles any necessary + // endianness conversions between the database server and the client. + // Determine return type based on the InfoType - // String types usually have InfoType > 10000 + // String types usually have InfoType > 10000 or are specifically known string values if (infoType > 10000 || infoType == SQL_DATA_SOURCE_NAME || infoType == SQL_DBMS_NAME || infoType == SQL_DBMS_VER || infoType == SQL_DRIVER_NAME || - infoType == SQL_DRIVER_VER) { + infoType == SQL_DRIVER_VER || + // Add missing string types + infoType == SQL_IDENTIFIER_QUOTE_CHAR || + infoType == SQL_CATALOG_NAME_SEPARATOR || + infoType == SQL_CATALOG_TERM || + infoType == SQL_SCHEMA_TERM || + infoType == SQL_TABLE_TERM || + infoType == SQL_KEYWORDS || + infoType == SQL_PROCEDURE_TERM) { // Return as string - return py::str(charBuffer); + return py::str(buffer.data()); } else if (infoType == SQL_DRIVER_ODBC_VER || infoType == SQL_SERVER_NAME) { // Return as string - return py::str(charBuffer); + return py::str(buffer.data()); } else { - // For numeric types, we need to interpret the buffer based on the expected return type - // Handle common numeric types + // For numeric types, use memcpy to safely extract the values + // This avoids potential alignment issues with direct casting + + // Ensure buffer has enough data for the expected type switch (infoType) { // 16-bit unsigned integers case SQL_MAX_CONCURRENT_ACTIVITIES: case SQL_MAX_DRIVER_CONNECTIONS: case SQL_ODBC_API_CONFORMANCE: case SQL_ODBC_SQL_CONFORMANCE: + case SQL_TXN_CAPABLE: // Add missing numeric types + case SQL_MULTIPLE_ACTIVE_TXN: + case SQL_MAX_COLUMN_NAME_LEN: + case SQL_MAX_TABLE_NAME_LEN: + case SQL_PROCEDURES: { - SQLUSMALLINT value = *reinterpret_cast(charBuffer); - return py::int_(value); + if (actualLength >= sizeof(SQLUSMALLINT)) { + SQLUSMALLINT value; + std::memcpy(&value, buffer.data(), sizeof(SQLUSMALLINT)); + return py::int_(value); + } + break; } // 32-bit unsigned integers @@ -379,9 +421,15 @@ py::object Connection::getInfo(SQLUSMALLINT infoType) const { case SQL_STRING_FUNCTIONS: case SQL_SYSTEM_FUNCTIONS: case SQL_TIMEDATE_FUNCTIONS: + case SQL_DEFAULT_TXN_ISOLATION: // Add missing numeric types + case SQL_MAX_STATEMENT_LEN: { - SQLUINTEGER value = *reinterpret_cast(charBuffer); - return py::int_(value); + if (actualLength >= sizeof(SQLUINTEGER)) { + SQLUINTEGER value; + std::memcpy(&value, buffer.data(), sizeof(SQLUINTEGER)); + return py::int_(value); + } + break; } // Boolean flags (32-bit mask) @@ -401,18 +449,44 @@ py::object Connection::getInfo(SQLUSMALLINT infoType) const { case SQL_STATIC_CURSOR_ATTRIBUTES1: case SQL_STATIC_CURSOR_ATTRIBUTES2: { - SQLUINTEGER value = *reinterpret_cast(charBuffer); - return py::int_(value); + if (actualLength >= sizeof(SQLUINTEGER)) { + SQLUINTEGER value; + std::memcpy(&value, buffer.data(), sizeof(SQLUINTEGER)); + return py::int_(value); + } + break; } - // Handle any other types as integers + // Handle any other types as integers, if enough data default: - SQLUINTEGER value = *reinterpret_cast(charBuffer); - return py::int_(value); + if (actualLength >= sizeof(SQLUINTEGER)) { + SQLUINTEGER value; + std::memcpy(&value, buffer.data(), sizeof(SQLUINTEGER)); + return py::int_(value); + } + else if (actualLength >= sizeof(SQLUSMALLINT)) { + SQLUSMALLINT value; + std::memcpy(&value, buffer.data(), sizeof(SQLUSMALLINT)); + return py::int_(value); + } + // For very small integers (like bytes/chars) + else if (actualLength > 0) { + // Try to interpret as a small integer + unsigned char value; + std::memcpy(&value, buffer.data(), sizeof(unsigned char)); + return py::int_(value); + } + break; } } - // Default return in case nothing matched + // If we get here and actualLength > 0, try to return as string as a last resort + if (actualLength > 0) { + return py::str(buffer.data()); + } + + // Default return in case nothing matched or buffer is too small + LOG("Unable to convert result for info type {}", infoType); return py::none(); } From 5df084f9f2d523889e3d9e093d8cefbe29dbf93e Mon Sep 17 00:00:00 2001 From: Jahnvi Thakkar <61936179+jahnvi480@users.noreply.github.com> Date: Thu, 18 Sep 2025 18:52:16 +0530 Subject: [PATCH 6/7] FEAT: Adding searchescape attribute for conn class (#202) ### Work Item / Issue Reference > [AB#34909](https://sqlclientdrivers.visualstudio.com/c6d89619-62de-46a0-8b46-70b92a84d85e/_workitems/edit/34909) ------------------------------------------------------------------- ### Summary This pull request adds a new `searchescape` property to the `Connection` class in `mssql_python`, which exposes the ODBC search pattern escape character used for escaping special characters in SQL LIKE clauses. It also introduces comprehensive tests to ensure correct and consistent behavior of this property, including its use in various SQL queries and its caching mechanism. Enhancements to connection escape character handling: * Added a `searchescape` property to the `Connection` class in `connection.py`, which retrieves and caches the ODBC search pattern escape character using `SQLGetInfo`, with sensible defaults and error handling. * Imported `GetInfoConstants` in `connection.py` to support the new property. Testing improvements: * Added multiple tests in `test_003_connection.py` to verify the `searchescape` property's basic functionality, its use in SQL LIKE queries (with `%`, `_`, and bracket wildcards), multiple escape scenarios, and property consistency and caching. --------- Co-authored-by: Jahnvi Thakkar --- mssql_python/connection.py | 25 ++++++ tests/test_003_connection.py | 165 ++++++++++++++++++++++++++++++++++- 2 files changed, 189 insertions(+), 1 deletion(-) diff --git a/mssql_python/connection.py b/mssql_python/connection.py index 2a7af1464..cc1701fb0 100644 --- a/mssql_python/connection.py +++ b/mssql_python/connection.py @@ -18,6 +18,7 @@ from mssql_python.pooling import PoolingManager from mssql_python.exceptions import InterfaceError from mssql_python.auth import process_connection_string +from mssql_python.constants import GetInfoConstants class Connection: @@ -159,6 +160,30 @@ def setautocommit(self, value: bool = False) -> None: """ self._conn.set_autocommit(value) + @property + def searchescape(self): + """ + The ODBC search pattern escape character, as returned by + SQLGetInfo(SQL_SEARCH_PATTERN_ESCAPE), used to escape special characters + such as '%' and '_' in LIKE clauses. These are driver specific. + + Returns: + str: The search pattern escape character (usually '\' or another character) + """ + if not hasattr(self, '_searchescape'): + try: + escape_char = self.getinfo(GetInfoConstants.SQL_SEARCH_PATTERN_ESCAPE.value) + # Some drivers might return this as an integer memory address + # or other non-string format, so ensure we have a string + if not isinstance(escape_char, str): + escape_char = '\\' # Default to backslash if not a string + self._searchescape = escape_char + except Exception as e: + # Log the exception for debugging, but do not expose sensitive info + log('warning', f"Failed to retrieve search escape character, using default '\\'. Exception: {type(e).__name__}") + self._searchescape = '\\' + return self._searchescape + def cursor(self) -> Cursor: """ Return a new Cursor object using the connection. diff --git a/tests/test_003_connection.py b/tests/test_003_connection.py index 8ddb33ef0..4937f6853 100644 --- a/tests/test_003_connection.py +++ b/tests/test_003_connection.py @@ -698,4 +698,167 @@ def test_getinfo_standard_types(db_connection): except Exception as e: # Log but don't fail - some drivers might not support all info types - print(f"Info type {info_type} failed: {e}") \ No newline at end of file + print(f"Info type {info_type} failed: {e}") + +def test_connection_searchescape_basic(db_connection): + """Test the basic functionality of the searchescape property.""" + # Get the search escape character + escape_char = db_connection.searchescape + + # Verify it's not None + assert escape_char is not None, "Search escape character should not be None" + print(f"Search pattern escape character: '{escape_char}'") + + # Test property caching - calling it twice should return the same value + escape_char2 = db_connection.searchescape + assert escape_char == escape_char2, "Search escape character should be consistent" + +def test_connection_searchescape_with_percent(db_connection): + """Test using the searchescape property with percent wildcard.""" + escape_char = db_connection.searchescape + + # Skip test if we got a non-string or empty escape character + if not isinstance(escape_char, str) or not escape_char: + pytest.skip("No valid escape character available for testing") + + cursor = db_connection.cursor() + try: + # Create a temporary table with data containing % character + cursor.execute("CREATE TABLE #test_escape_percent (id INT, text VARCHAR(50))") + cursor.execute("INSERT INTO #test_escape_percent VALUES (1, 'abc%def')") + cursor.execute("INSERT INTO #test_escape_percent VALUES (2, 'abc_def')") + cursor.execute("INSERT INTO #test_escape_percent VALUES (3, 'abcdef')") + + # Use the escape character to find the exact % character + query = f"SELECT * FROM #test_escape_percent WHERE text LIKE 'abc{escape_char}%def' ESCAPE '{escape_char}'" + cursor.execute(query) + results = cursor.fetchall() + + # Should match only the row with the % character + assert len(results) == 1, f"Escaped LIKE query for % matched {len(results)} rows instead of 1" + if results: + assert 'abc%def' in results[0][1], "Escaped LIKE query did not match correct row" + + except Exception as e: + print(f"Note: LIKE escape test with % failed: {e}") + # Don't fail the test as some drivers might handle escaping differently + finally: + cursor.execute("DROP TABLE #test_escape_percent") + +def test_connection_searchescape_with_underscore(db_connection): + """Test using the searchescape property with underscore wildcard.""" + escape_char = db_connection.searchescape + + # Skip test if we got a non-string or empty escape character + if not isinstance(escape_char, str) or not escape_char: + pytest.skip("No valid escape character available for testing") + + cursor = db_connection.cursor() + try: + # Create a temporary table with data containing _ character + cursor.execute("CREATE TABLE #test_escape_underscore (id INT, text VARCHAR(50))") + cursor.execute("INSERT INTO #test_escape_underscore VALUES (1, 'abc_def')") + cursor.execute("INSERT INTO #test_escape_underscore VALUES (2, 'abcXdef')") # 'X' could match '_' + cursor.execute("INSERT INTO #test_escape_underscore VALUES (3, 'abcdef')") # No match + + # Use the escape character to find the exact _ character + query = f"SELECT * FROM #test_escape_underscore WHERE text LIKE 'abc{escape_char}_def' ESCAPE '{escape_char}'" + cursor.execute(query) + results = cursor.fetchall() + + # Should match only the row with the _ character + assert len(results) == 1, f"Escaped LIKE query for _ matched {len(results)} rows instead of 1" + if results: + assert 'abc_def' in results[0][1], "Escaped LIKE query did not match correct row" + + except Exception as e: + print(f"Note: LIKE escape test with _ failed: {e}") + # Don't fail the test as some drivers might handle escaping differently + finally: + cursor.execute("DROP TABLE #test_escape_underscore") + +def test_connection_searchescape_with_brackets(db_connection): + """Test using the searchescape property with bracket wildcards.""" + escape_char = db_connection.searchescape + + # Skip test if we got a non-string or empty escape character + if not isinstance(escape_char, str) or not escape_char: + pytest.skip("No valid escape character available for testing") + + cursor = db_connection.cursor() + try: + # Create a temporary table with data containing [ character + cursor.execute("CREATE TABLE #test_escape_brackets (id INT, text VARCHAR(50))") + cursor.execute("INSERT INTO #test_escape_brackets VALUES (1, 'abc[x]def')") + cursor.execute("INSERT INTO #test_escape_brackets VALUES (2, 'abcxdef')") + + # Use the escape character to find the exact [ character + # Note: This might not work on all drivers as bracket escaping varies + query = f"SELECT * FROM #test_escape_brackets WHERE text LIKE 'abc{escape_char}[x{escape_char}]def' ESCAPE '{escape_char}'" + cursor.execute(query) + results = cursor.fetchall() + + # Just check we got some kind of result without asserting specific behavior + print(f"Bracket escaping test returned {len(results)} rows") + + except Exception as e: + print(f"Note: LIKE escape test with brackets failed: {e}") + # Don't fail the test as bracket escaping varies significantly between drivers + finally: + cursor.execute("DROP TABLE #test_escape_brackets") + +def test_connection_searchescape_multiple_escapes(db_connection): + """Test using the searchescape property with multiple escape sequences.""" + escape_char = db_connection.searchescape + + # Skip test if we got a non-string or empty escape character + if not isinstance(escape_char, str) or not escape_char: + pytest.skip("No valid escape character available for testing") + + cursor = db_connection.cursor() + try: + # Create a temporary table with data containing multiple special chars + cursor.execute("CREATE TABLE #test_multiple_escapes (id INT, text VARCHAR(50))") + cursor.execute("INSERT INTO #test_multiple_escapes VALUES (1, 'abc%def_ghi')") + cursor.execute("INSERT INTO #test_multiple_escapes VALUES (2, 'abc%defXghi')") # Wouldn't match the pattern + cursor.execute("INSERT INTO #test_multiple_escapes VALUES (3, 'abcXdef_ghi')") # Wouldn't match the pattern + + # Use escape character for both % and _ + query = f""" + SELECT * FROM #test_multiple_escapes + WHERE text LIKE 'abc{escape_char}%def{escape_char}_ghi' ESCAPE '{escape_char}' + """ + cursor.execute(query) + results = cursor.fetchall() + + # Should match only the row with both % and _ + assert len(results) <= 1, f"Multiple escapes query matched {len(results)} rows instead of at most 1" + if len(results) == 1: + assert 'abc%def_ghi' in results[0][1], "Multiple escapes query matched incorrect row" + + except Exception as e: + print(f"Note: Multiple escapes test failed: {e}") + # Don't fail the test as escaping behavior varies + finally: + cursor.execute("DROP TABLE #test_multiple_escapes") + +def test_connection_searchescape_consistency(db_connection): + """Test that the searchescape property is cached and consistent.""" + # Call the property multiple times + escape1 = db_connection.searchescape + escape2 = db_connection.searchescape + escape3 = db_connection.searchescape + + # All calls should return the same value + assert escape1 == escape2 == escape3, "Searchescape property should be consistent" + + # Create a new connection and verify it returns the same escape character + # (assuming the same driver and connection settings) + if 'conn_str' in globals(): + try: + new_conn = connect(conn_str) + new_escape = new_conn.searchescape + assert new_escape == escape1, "Searchescape should be consistent across connections" + new_conn.close() + except Exception as e: + print(f"Note: New connection comparison failed: {e}") \ No newline at end of file From 76ca4d833d83b41bdc3422503ed49215f8d67a35 Mon Sep 17 00:00:00 2001 From: Jahnvi Thakkar Date: Thu, 18 Sep 2025 23:46:12 +0530 Subject: [PATCH 7/7] Resolving conflicts --- mssql_python/constants.py | 2 +- mssql_python/pybind/connection/connection.cpp | 55 +++++++++++-------- tests/test_003_connection.py | 2 +- 3 files changed, 35 insertions(+), 24 deletions(-) diff --git a/mssql_python/constants.py b/mssql_python/constants.py index 413ae94ec..2e56112e5 100644 --- a/mssql_python/constants.py +++ b/mssql_python/constants.py @@ -177,7 +177,7 @@ class GetInfoConstants(Enum): SQL_CATALOG_USAGE = 92 SQL_SCHEMA_USAGE = 91 SQL_COLUMN_ALIAS = 87 - SQL_DESCRIBE_PARAMETER = 10002 + SQL_DESCRIBE_PARAMETER = 10003 # Transaction support SQL_TXN_CAPABLE = 46 diff --git a/mssql_python/pybind/connection/connection.cpp b/mssql_python/pybind/connection/connection.cpp index 04ad28778..d28ee7c50 100644 --- a/mssql_python/pybind/connection/connection.cpp +++ b/mssql_python/pybind/connection/connection.cpp @@ -382,8 +382,7 @@ py::object Connection::getInfo(SQLUSMALLINT infoType) const { return py::str(buffer.data()); } else { - // For numeric types, use memcpy to safely extract the values - // This avoids potential alignment issues with direct casting + // For numeric types, safely extract values // Ensure buffer has enough data for the expected type switch (infoType) { @@ -392,15 +391,17 @@ py::object Connection::getInfo(SQLUSMALLINT infoType) const { case SQL_MAX_DRIVER_CONNECTIONS: case SQL_ODBC_API_CONFORMANCE: case SQL_ODBC_SQL_CONFORMANCE: - case SQL_TXN_CAPABLE: // Add missing numeric types + case SQL_TXN_CAPABLE: case SQL_MULTIPLE_ACTIVE_TXN: case SQL_MAX_COLUMN_NAME_LEN: case SQL_MAX_TABLE_NAME_LEN: case SQL_PROCEDURES: { - if (actualLength >= sizeof(SQLUSMALLINT)) { - SQLUSMALLINT value; - std::memcpy(&value, buffer.data(), sizeof(SQLUSMALLINT)); + if (actualLength >= sizeof(SQLUSMALLINT) && buffer.size() >= sizeof(SQLUSMALLINT)) { + SQLUSMALLINT value = 0; + // Safely copy data by using std::copy instead of memcpy + std::copy(buffer.begin(), buffer.begin() + sizeof(SQLUSMALLINT), + reinterpret_cast(&value)); return py::int_(value); } break; @@ -421,12 +422,14 @@ py::object Connection::getInfo(SQLUSMALLINT infoType) const { case SQL_STRING_FUNCTIONS: case SQL_SYSTEM_FUNCTIONS: case SQL_TIMEDATE_FUNCTIONS: - case SQL_DEFAULT_TXN_ISOLATION: // Add missing numeric types + case SQL_DEFAULT_TXN_ISOLATION: case SQL_MAX_STATEMENT_LEN: { - if (actualLength >= sizeof(SQLUINTEGER)) { - SQLUINTEGER value; - std::memcpy(&value, buffer.data(), sizeof(SQLUINTEGER)); + if (actualLength >= sizeof(SQLUINTEGER) && buffer.size() >= sizeof(SQLUINTEGER)) { + SQLUINTEGER value = 0; + // Safely copy data by using std::copy instead of memcpy + std::copy(buffer.begin(), buffer.begin() + sizeof(SQLUINTEGER), + reinterpret_cast(&value)); return py::int_(value); } break; @@ -449,9 +452,11 @@ py::object Connection::getInfo(SQLUSMALLINT infoType) const { case SQL_STATIC_CURSOR_ATTRIBUTES1: case SQL_STATIC_CURSOR_ATTRIBUTES2: { - if (actualLength >= sizeof(SQLUINTEGER)) { - SQLUINTEGER value; - std::memcpy(&value, buffer.data(), sizeof(SQLUINTEGER)); + if (actualLength >= sizeof(SQLUINTEGER) && buffer.size() >= sizeof(SQLUINTEGER)) { + SQLUINTEGER value = 0; + // Safely copy data by using std::copy instead of memcpy + std::copy(buffer.begin(), buffer.begin() + sizeof(SQLUINTEGER), + reinterpret_cast(&value)); return py::int_(value); } break; @@ -459,21 +464,27 @@ py::object Connection::getInfo(SQLUSMALLINT infoType) const { // Handle any other types as integers, if enough data default: - if (actualLength >= sizeof(SQLUINTEGER)) { - SQLUINTEGER value; - std::memcpy(&value, buffer.data(), sizeof(SQLUINTEGER)); + if (actualLength >= sizeof(SQLUINTEGER) && buffer.size() >= sizeof(SQLUINTEGER)) { + SQLUINTEGER value = 0; + // Safely copy data by using std::copy instead of memcpy + std::copy(buffer.begin(), buffer.begin() + sizeof(SQLUINTEGER), + reinterpret_cast(&value)); return py::int_(value); } - else if (actualLength >= sizeof(SQLUSMALLINT)) { - SQLUSMALLINT value; - std::memcpy(&value, buffer.data(), sizeof(SQLUSMALLINT)); + else if (actualLength >= sizeof(SQLUSMALLINT) && buffer.size() >= sizeof(SQLUSMALLINT)) { + SQLUSMALLINT value = 0; + // Safely copy data by using std::copy instead of memcpy + std::copy(buffer.begin(), buffer.begin() + sizeof(SQLUSMALLINT), + reinterpret_cast(&value)); return py::int_(value); } // For very small integers (like bytes/chars) - else if (actualLength > 0) { + else if (actualLength > 0 && buffer.size() >= sizeof(unsigned char)) { // Try to interpret as a small integer - unsigned char value; - std::memcpy(&value, buffer.data(), sizeof(unsigned char)); + unsigned char value = 0; + // Safely copy data by using std::copy instead of memcpy + std::copy(buffer.begin(), buffer.begin() + sizeof(unsigned char), + reinterpret_cast(&value)); return py::int_(value); } break; diff --git a/tests/test_003_connection.py b/tests/test_003_connection.py index 1f7968e8c..fe2625c7f 100644 --- a/tests/test_003_connection.py +++ b/tests/test_003_connection.py @@ -4831,7 +4831,7 @@ def test_getinfo_catalog_support(db_connection): try: # Catalog support for tables catalog_term = db_connection.getinfo(sql_const.SQL_CATALOG_TERM.value) - print("Catalof term = ",catalog_term) + print("Catalog term = ",catalog_term) assert catalog_term is not None, "Catalog term should not be None" # Catalog name separator