diff --git a/mssql_python/connection.py b/mssql_python/connection.py index 892dead91..2a81009b1 100644 --- a/mssql_python/connection.py +++ b/mssql_python/connection.py @@ -14,6 +14,7 @@ import re import codecs from typing import Any +import threading from mssql_python.cursor import Cursor from mssql_python.helpers import add_driver_to_connection_str, sanitize_connection_string, sanitize_user_input, log from mssql_python import ddbc_bindings @@ -187,6 +188,10 @@ def __init__(self, connection_str: str = "", autocommit: bool = False, attrs_bef # TODO: Think and implement scenarios for multi-threaded access to cursors self._cursors = weakref.WeakSet() + # Initialize output converters dictionary and its lock for thread safety + self._output_converters = {} + self._converters_lock = threading.Lock() + # Auto-enable pooling if user never called if not PoolingManager.is_initialized(): PoolingManager.enable() @@ -531,6 +536,92 @@ def cursor(self) -> Cursor: cursor = Cursor(self) self._cursors.add(cursor) # Track the cursor return cursor + + def add_output_converter(self, sqltype, func) -> None: + """ + Register an output converter function that will be called whenever a value + with the given SQL type is read from the database. + + Thread-safe implementation that protects the converters dictionary with a lock. + + ⚠️ WARNING: Registering an output converter will cause the supplied Python function + to be executed on every matching database value. Do not register converters from + untrusted sources, as this can result in arbitrary code execution and security + vulnerabilities. This API should never be exposed to untrusted or external input. + + Args: + sqltype (int): The integer SQL type value to convert, which can be one of the + defined standard constants (e.g. SQL_VARCHAR) or a database-specific + value (e.g. -151 for the SQL Server 2008 geometry data type). + func (callable): The converter function which will be called with a single parameter, + the value, and should return the converted value. If the value is NULL + then the parameter passed to the function will be None, otherwise it + will be a bytes object. + + Returns: + None + """ + with self._converters_lock: + self._output_converters[sqltype] = func + # Pass to the underlying connection if native implementation supports it + if hasattr(self._conn, 'add_output_converter'): + self._conn.add_output_converter(sqltype, func) + log('info', f"Added output converter for SQL type {sqltype}") + + def get_output_converter(self, sqltype): + """ + Get the output converter function for the specified SQL type. + + Thread-safe implementation that protects the converters dictionary with a lock. + + Args: + sqltype (int or type): The SQL type value or Python type to get the converter for + + Returns: + callable or None: The converter function or None if no converter is registered + + Note: + ⚠️ The returned converter function will be executed on database values. Only use + converters from trusted sources. + """ + with self._converters_lock: + return self._output_converters.get(sqltype) + + def remove_output_converter(self, sqltype): + """ + Remove the output converter function for the specified SQL type. + + Thread-safe implementation that protects the converters dictionary with a lock. + + Args: + sqltype (int or type): The SQL type value to remove the converter for + + Returns: + None + """ + with self._converters_lock: + if sqltype in self._output_converters: + del self._output_converters[sqltype] + # Pass to the underlying connection if native implementation supports it + if hasattr(self._conn, 'remove_output_converter'): + self._conn.remove_output_converter(sqltype) + log('info', f"Removed output converter for SQL type {sqltype}") + + def clear_output_converters(self) -> None: + """ + Remove all output converter functions. + + Thread-safe implementation that protects the converters dictionary with a lock. + + Returns: + None + """ + with self._converters_lock: + self._output_converters.clear() + # Pass to the underlying connection if native implementation supports it + if hasattr(self._conn, 'clear_output_converters'): + self._conn.clear_output_converters() + log('info', "Cleared all output converters") def execute(self, sql: str, *args: Any) -> Cursor: """ diff --git a/mssql_python/row.py b/mssql_python/row.py index c7522fbf5..5e749a67c 100644 --- a/mssql_python/row.py +++ b/mssql_python/row.py @@ -24,7 +24,13 @@ def __init__(self, cursor, description, values, column_map=None): column_map: Optional pre-built column map (for optimization) """ self._cursor = cursor - self._values = values + self._description = description + + # Apply output converters if available + if hasattr(cursor.connection, '_output_converters') and cursor.connection._output_converters: + self._values = self._apply_output_converters(values) + else: + self._values = values # TODO: ADO task - Optimize memory usage by sharing column map across rows # Instead of storing the full cursor_description in each Row object: @@ -42,6 +48,57 @@ def __init__(self, cursor, description, values, column_map=None): self._column_map = column_map + def _apply_output_converters(self, values): + """ + Apply output converters to raw values. + + Args: + values: Raw values from the database + + Returns: + List of converted values + """ + if not self._description: + return values + + converted_values = list(values) + + for i, (value, desc) in enumerate(zip(values, self._description)): + if desc is None or value is None: + continue + + # Get SQL type from description + sql_type = desc[1] # type_code is at index 1 in description tuple + + # Try to get a converter for this type + converter = self._cursor.connection.get_output_converter(sql_type) + + # If no converter found for the SQL type but the value is a string or bytes, + # try the WVARCHAR converter as a fallback + if converter is None and isinstance(value, (str, bytes)): + from mssql_python.constants import ConstantsDDBC + converter = self._cursor.connection.get_output_converter(ConstantsDDBC.SQL_WVARCHAR.value) + + # If we found a converter, apply it + if converter: + try: + # If value is already a Python type (str, int, etc.), + # we need to convert it to bytes for our converters + if isinstance(value, str): + # Encode as UTF-16LE for string values (SQL_WVARCHAR format) + value_bytes = value.encode('utf-16-le') + converted_values[i] = converter(value_bytes) + else: + converted_values[i] = converter(value) + except Exception: + # Log the exception for debugging without leaking sensitive data + if hasattr(self._cursor, 'log'): + self._cursor.log('debug', 'Exception occurred in output converter', exc_info=True) + # If conversion fails, keep the original value + pass + + return converted_values + def __getitem__(self, index): """Allow accessing by numeric index: row[0]""" return self._values[index] diff --git a/requirements.txt b/requirements.txt index a4312a3dc..5abf13dc0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,3 +4,4 @@ pybind11 coverage unittest-xml-reporting setuptools +psutil \ No newline at end of file diff --git a/tests/test_003_connection.py b/tests/test_003_connection.py index f8b439fb8..902963a2e 100644 --- a/tests/test_003_connection.py +++ b/tests/test_003_connection.py @@ -23,7 +23,6 @@ from mssql_python.exceptions import InterfaceError, ProgrammingError import mssql_python -import datetime import pytest import time from mssql_python import connect, Connection, pooling, SQL_CHAR, SQL_WCHAR @@ -43,6 +42,9 @@ ProgrammingError, NotSupportedError, ) +import struct +from datetime import datetime, timedelta, timezone +from mssql_python.constants import ConstantsDDBC @pytest.fixture(autouse=True) def clean_connection_state(db_connection): @@ -74,6 +76,31 @@ def drop_table_if_exists(cursor, table_name): except Exception as e: pytest.fail(f"Failed to drop table {table_name}: {e}") +# Add these helper functions after other helper functions +def handle_datetimeoffset(dto_value): + """Converter function for SQL Server's DATETIMEOFFSET type""" + if dto_value is None: + return None + + # The format depends on the ODBC driver and how it returns binary data + # This matches SQL Server's format for DATETIMEOFFSET + tup = struct.unpack("<6hI2h", dto_value) # e.g., (2017, 3, 16, 10, 35, 18, 500000000, -6, 0) + return datetime( + tup[0], tup[1], tup[2], tup[3], tup[4], tup[5], tup[6] // 1000, + timezone(timedelta(hours=tup[7], minutes=tup[8])) + ) + +def custom_string_converter(value): + """ + A simple converter that adds a prefix to string values. + Assumes SQL_WVARCHAR is UTF-16LE encoded by default, + but this may vary depending on the database configuration. + You can specify a different encoding if needed. + """ + if value is None: + return None + return "CONVERTED: " + value.decode('utf-16-le') # SQL_WVARCHAR is UTF-16LE encoded + def test_connection_string(conn_str): # Check if the connection string is not None assert conn_str is not None, "Connection string should not be None" @@ -3320,7 +3347,7 @@ def test_batch_execute_basic(db_connection): assert results[1][0][0] == 'test', "Second result should be 'test'" assert len(results[2]) == 1, "Expected 1 row in third result" - assert isinstance(results[2][0][0], (str, datetime.datetime)), "Third result should be a date" + assert isinstance(results[2][0][0], (str, datetime)), "Third result should be a date" # Cursor should be usable after batch execution cursor.execute("SELECT 2 AS another_value") @@ -3616,4 +3643,436 @@ def test_batch_execute_large_batch(db_connection): assert results[25][0][0] == 25, "Middle result should be 25" assert results[49][0][0] == 49, "Last result should be 49" - cursor.close() \ No newline at end of file + cursor.close() +def test_connection_execute(db_connection): + """Test the execute() convenience method for Connection class""" + # Test basic execution + cursor = db_connection.execute("SELECT 1 AS test_value") + result = cursor.fetchone() + assert result is not None, "Execute failed: No result returned" + assert result[0] == 1, "Execute failed: Incorrect result" + + # Test with parameters + cursor = db_connection.execute("SELECT ? AS test_value", 42) + result = cursor.fetchone() + assert result is not None, "Execute with parameters failed: No result returned" + assert result[0] == 42, "Execute with parameters failed: Incorrect result" + + # Test that cursor is tracked by connection + assert cursor in db_connection._cursors, "Cursor from execute() not tracked by connection" + + # Test with data modification and verify it requires commit + if not db_connection.autocommit: + drop_table_if_exists(db_connection.cursor(), "#pytest_test_execute") + cursor1 = db_connection.execute("CREATE TABLE #pytest_test_execute (id INT, value VARCHAR(50))") + cursor2 = db_connection.execute("INSERT INTO #pytest_test_execute VALUES (1, 'test_value')") + cursor3 = db_connection.execute("SELECT * FROM #pytest_test_execute") + result = cursor3.fetchone() + assert result is not None, "Execute with table creation failed" + assert result[0] == 1, "Execute with table creation returned wrong id" + assert result[1] == 'test_value', "Execute with table creation returned wrong value" + + # Clean up + db_connection.execute("DROP TABLE #pytest_test_execute") + db_connection.commit() + +def test_connection_execute_error_handling(db_connection): + """Test that execute() properly handles SQL errors""" + with pytest.raises(Exception): + db_connection.execute("SELECT * FROM nonexistent_table") + +def test_connection_execute_empty_result(db_connection): + """Test execute() with a query that returns no rows""" + cursor = db_connection.execute("SELECT * FROM sys.tables WHERE name = 'nonexistent_table_name'") + result = cursor.fetchone() + assert result is None, "Query should return no results" + + # Test empty result with fetchall + rows = cursor.fetchall() + assert len(rows) == 0, "fetchall should return empty list for empty result set" + +def test_connection_execute_different_parameter_types(db_connection): + """Test execute() with different parameter data types""" + # Test with different data types + params = [ + 1234, # Integer + 3.14159, # Float + "test string", # String + bytearray(b'binary data'), # Binary data + True, # Boolean + None # NULL + ] + + for param in params: + cursor = db_connection.execute("SELECT ? AS value", param) + result = cursor.fetchone() + if param is None: + assert result[0] is None, "NULL parameter not handled correctly" + else: + assert result[0] == param, f"Parameter {param} of type {type(param)} not handled correctly" + +def test_connection_execute_with_transaction(db_connection): + """Test execute() in the context of explicit transactions""" + if db_connection.autocommit: + db_connection.autocommit = False + + cursor1 = db_connection.cursor() + drop_table_if_exists(cursor1, "#pytest_test_execute_transaction") + + try: + # Create table and insert data + db_connection.execute("CREATE TABLE #pytest_test_execute_transaction (id INT, value VARCHAR(50))") + db_connection.execute("INSERT INTO #pytest_test_execute_transaction VALUES (1, 'before rollback')") + + # Check data is there + cursor = db_connection.execute("SELECT * FROM #pytest_test_execute_transaction") + result = cursor.fetchone() + assert result is not None, "Data should be visible within transaction" + assert result[1] == 'before rollback', "Incorrect data in transaction" + + # Rollback and verify data is gone + db_connection.rollback() + + # Need to recreate table since it was rolled back + db_connection.execute("CREATE TABLE #pytest_test_execute_transaction (id INT, value VARCHAR(50))") + db_connection.execute("INSERT INTO #pytest_test_execute_transaction VALUES (2, 'after rollback')") + + cursor = db_connection.execute("SELECT * FROM #pytest_test_execute_transaction") + result = cursor.fetchone() + assert result is not None, "Data should be visible after new insert" + assert result[0] == 2, "Should see the new data after rollback" + assert result[1] == 'after rollback', "Incorrect data after rollback" + + # Commit and verify data persists + db_connection.commit() + finally: + # Clean up + try: + db_connection.execute("DROP TABLE #pytest_test_execute_transaction") + db_connection.commit() + except Exception: + pass + +def test_connection_execute_vs_cursor_execute(db_connection): + """Compare behavior of connection.execute() vs cursor.execute()""" + # Connection.execute creates a new cursor each time + cursor1 = db_connection.execute("SELECT 1 AS first_query") + # Consume the results from cursor1 before creating cursor2 + result1 = cursor1.fetchall() + assert result1[0][0] == 1, "First cursor should have result from first query" + + # Now it's safe to create a second cursor + cursor2 = db_connection.execute("SELECT 2 AS second_query") + result2 = cursor2.fetchall() + assert result2[0][0] == 2, "Second cursor should have result from second query" + + # These should be different cursor objects + assert cursor1 != cursor2, "Connection.execute should create a new cursor each time" + + # Now compare with reusing the same cursor + cursor3 = db_connection.cursor() + cursor3.execute("SELECT 3 AS third_query") + result3 = cursor3.fetchone() + assert result3[0] == 3, "Direct cursor execution failed" + + # Reuse the same cursor + cursor3.execute("SELECT 4 AS fourth_query") + result4 = cursor3.fetchone() + assert result4[0] == 4, "Reused cursor should have new results" + + # The previous results should no longer be accessible + cursor3.execute("SELECT 3 AS third_query_again") + result5 = cursor3.fetchone() + assert result5[0] == 3, "Cursor reexecution should work" + +def test_connection_execute_many_parameters(db_connection): + """Test execute() with many parameters""" + # First make sure no active results are pending + # by using a fresh cursor and fetching all results + cursor = db_connection.cursor() + cursor.execute("SELECT 1") + cursor.fetchall() + + # Create a query with 10 parameters + params = list(range(1, 11)) + query = "SELECT " + ", ".join(["?" for _ in params]) + " AS many_params" + + # Now execute with many parameters + cursor = db_connection.execute(query, *params) + result = cursor.fetchall() # Use fetchall to consume all results + + # Verify all parameters were correctly passed + for i, value in enumerate(params): + assert result[0][i] == value, f"Parameter at position {i} not correctly passed" + +def test_add_output_converter(db_connection): + """Test adding an output converter""" + # Add a converter + sql_wvarchar = ConstantsDDBC.SQL_WVARCHAR.value + db_connection.add_output_converter(sql_wvarchar, custom_string_converter) + + # Verify it was added correctly + assert hasattr(db_connection, '_output_converters') + assert sql_wvarchar in db_connection._output_converters + assert db_connection._output_converters[sql_wvarchar] == custom_string_converter + + # Clean up + db_connection.clear_output_converters() + +def test_get_output_converter(db_connection): + """Test getting an output converter""" + sql_wvarchar = ConstantsDDBC.SQL_WVARCHAR.value + + # Initial state - no converter + assert db_connection.get_output_converter(sql_wvarchar) is None + + # Add a converter + db_connection.add_output_converter(sql_wvarchar, custom_string_converter) + + # Get the converter + converter = db_connection.get_output_converter(sql_wvarchar) + assert converter == custom_string_converter + + # Get a non-existent converter + assert db_connection.get_output_converter(999) is None + + # Clean up + db_connection.clear_output_converters() + +def test_remove_output_converter(db_connection): + """Test removing an output converter""" + sql_wvarchar = ConstantsDDBC.SQL_WVARCHAR.value + + # Add a converter + db_connection.add_output_converter(sql_wvarchar, custom_string_converter) + assert db_connection.get_output_converter(sql_wvarchar) is not None + + # Remove the converter + db_connection.remove_output_converter(sql_wvarchar) + assert db_connection.get_output_converter(sql_wvarchar) is None + + # Remove a non-existent converter (should not raise) + db_connection.remove_output_converter(999) + +def test_clear_output_converters(db_connection): + """Test clearing all output converters""" + sql_wvarchar = ConstantsDDBC.SQL_WVARCHAR.value + sql_timestamp_offset = ConstantsDDBC.SQL_TIMESTAMPOFFSET.value + + # Add multiple converters + db_connection.add_output_converter(sql_wvarchar, custom_string_converter) + db_connection.add_output_converter(sql_timestamp_offset, handle_datetimeoffset) + + # Verify converters were added + assert db_connection.get_output_converter(sql_wvarchar) is not None + assert db_connection.get_output_converter(sql_timestamp_offset) is not None + + # Clear all converters + db_connection.clear_output_converters() + + # Verify all converters were removed + assert db_connection.get_output_converter(sql_wvarchar) is None + assert db_connection.get_output_converter(sql_timestamp_offset) is None + +def test_converter_integration(db_connection): + """ + Test that converters work during fetching. + + This test verifies that output converters work at the Python level + without requiring native driver support. + """ + cursor = db_connection.cursor() + sql_wvarchar = ConstantsDDBC.SQL_WVARCHAR.value + + # Test with string converter + db_connection.add_output_converter(sql_wvarchar, custom_string_converter) + + # Test a simple string query + cursor.execute("SELECT N'test string' AS test_col") + row = cursor.fetchone() + + # Check if the type matches what we expect for SQL_WVARCHAR + # For Cursor.description, the second element is the type code + column_type = cursor.description[0][1] + + # If the cursor description has SQL_WVARCHAR as the type code, + # then our converter should be applied + if column_type == sql_wvarchar: + assert row[0].startswith("CONVERTED:"), "Output converter not applied" + else: + # If the type code is different, adjust the test or the converter + print(f"Column type is {column_type}, not {sql_wvarchar}") + # Add converter for the actual type used + db_connection.clear_output_converters() + db_connection.add_output_converter(column_type, custom_string_converter) + + # Re-execute the query + cursor.execute("SELECT N'test string' AS test_col") + row = cursor.fetchone() + assert row[0].startswith("CONVERTED:"), "Output converter not applied" + + # Clean up + db_connection.clear_output_converters() + +def test_output_converter_with_null_values(db_connection): + """Test that output converters handle NULL values correctly""" + cursor = db_connection.cursor() + sql_wvarchar = ConstantsDDBC.SQL_WVARCHAR.value + + # Add converter for string type + db_connection.add_output_converter(sql_wvarchar, custom_string_converter) + + # Execute a query with NULL values + cursor.execute("SELECT CAST(NULL AS NVARCHAR(50)) AS null_col") + value = cursor.fetchone()[0] + + # NULL values should remain None regardless of converter + assert value is None + + # Clean up + db_connection.clear_output_converters() + +def test_chaining_output_converters(db_connection): + """Test that output converters can be chained (replaced)""" + sql_wvarchar = ConstantsDDBC.SQL_WVARCHAR.value + + # Define a second converter + def another_string_converter(value): + if value is None: + return None + return "ANOTHER: " + value.decode('utf-16-le') + + # Add first converter + db_connection.add_output_converter(sql_wvarchar, custom_string_converter) + + # Verify first converter is registered + assert db_connection.get_output_converter(sql_wvarchar) == custom_string_converter + + # Replace with second converter + db_connection.add_output_converter(sql_wvarchar, another_string_converter) + + # Verify second converter replaced the first + assert db_connection.get_output_converter(sql_wvarchar) == another_string_converter + + # Clean up + db_connection.clear_output_converters() + +def test_temporary_converter_replacement(db_connection): + """Test temporarily replacing a converter and then restoring it""" + sql_wvarchar = ConstantsDDBC.SQL_WVARCHAR.value + + # Add a converter + db_connection.add_output_converter(sql_wvarchar, custom_string_converter) + + # Save original converter + original_converter = db_connection.get_output_converter(sql_wvarchar) + + # Define a temporary converter + def temp_converter(value): + if value is None: + return None + return "TEMP: " + value.decode('utf-16-le') + + # Replace with temporary converter + db_connection.add_output_converter(sql_wvarchar, temp_converter) + + # Verify temporary converter is in use + assert db_connection.get_output_converter(sql_wvarchar) == temp_converter + + # Restore original converter + db_connection.add_output_converter(sql_wvarchar, original_converter) + + # Verify original converter is restored + assert db_connection.get_output_converter(sql_wvarchar) == original_converter + + # Clean up + db_connection.clear_output_converters() + +def test_multiple_output_converters(db_connection): + """Test that multiple output converters can work together""" + cursor = db_connection.cursor() + + # Execute a query to get the actual type codes used + cursor.execute("SELECT CAST(42 AS INT) as int_col, N'test' as str_col") + int_type = cursor.description[0][1] # Type code for integer column + str_type = cursor.description[1][1] # Type code for string column + + # Add converter for string type + db_connection.add_output_converter(str_type, custom_string_converter) + + # Add converter for integer type + def int_converter(value): + if value is None: + return None + # Convert from bytes to int and multiply by 2 + if isinstance(value, bytes): + return int.from_bytes(value, byteorder='little') * 2 + elif isinstance(value, int): + return value * 2 + return value + + db_connection.add_output_converter(int_type, int_converter) + + # Test query with both types + cursor.execute("SELECT CAST(42 AS INT) as int_col, N'test' as str_col") + row = cursor.fetchone() + + # Verify converters worked + assert row[0] == 84, f"Integer converter failed, got {row[0]} instead of 84" + assert isinstance(row[1], str) and "CONVERTED:" in row[1], f"String converter failed, got {row[1]}" + + # Clean up + db_connection.clear_output_converters() + +def test_output_converter_exception_handling(db_connection): + """Test that exceptions in output converters are properly handled""" + cursor = db_connection.cursor() + + # First determine the actual type code for NVARCHAR + cursor.execute("SELECT N'test string' AS test_col") + str_type = cursor.description[0][1] + + # Define a converter that will raise an exception + def faulty_converter(value): + if value is None: + return None + # Intentionally raise an exception with potentially sensitive info + # This simulates a bug in a custom converter + raise ValueError(f"Converter error with sensitive data: {value!r}") + + # Add the faulty converter + db_connection.add_output_converter(str_type, faulty_converter) + + try: + # Execute a query that will trigger the converter + cursor.execute("SELECT N'test string' AS test_col") + + # Attempt to fetch data, which should trigger the converter + row = cursor.fetchone() + + # The implementation could handle this in different ways: + # 1. Fall back to returning the unconverted value + # 2. Return None for the problematic column + # 3. Raise a sanitized exception + + # If we got here, the exception was caught and handled internally + assert row is not None, "Row should still be returned despite converter error" + assert row[0] is not None, "Column value shouldn't be None despite converter error" + + # Verify we can continue using the connection + cursor.execute("SELECT 1 AS test") + assert cursor.fetchone()[0] == 1, "Connection should still be usable" + + except Exception as e: + # If an exception is raised, ensure it doesn't contain the sensitive info + error_str = str(e) + assert "sensitive data" not in error_str, f"Exception leaked sensitive data: {error_str}" + assert not isinstance(e, ValueError), "Original exception type should not be exposed" + + # Verify we can continue using the connection after the error + cursor.execute("SELECT 1 AS test") + assert cursor.fetchone()[0] == 1, "Connection should still be usable after converter error" + + finally: + # Clean up + db_connection.clear_output_converters() \ No newline at end of file diff --git a/tests/test_004_cursor.py b/tests/test_004_cursor.py index a74e13837..5c0c5f318 100644 --- a/tests/test_004_cursor.py +++ b/tests/test_004_cursor.py @@ -1816,75 +1816,6 @@ def test_row_column_mapping(cursor, db_connection): cursor.execute("DROP TABLE #pytest_row_test") db_connection.commit() -def test_lowercase_attribute(cursor, db_connection): - """Test that the lowercase attribute properly converts column names to lowercase""" - - # Store original value to restore after test - original_lowercase = mssql_python.lowercase - drop_cursor = None - - try: - # Create a test table with mixed-case column names - cursor.execute(""" - CREATE TABLE #pytest_lowercase_test ( - ID INT PRIMARY KEY, - UserName VARCHAR(50), - EMAIL_ADDRESS VARCHAR(100), - PhoneNumber VARCHAR(20) - ) - """) - db_connection.commit() - - # Insert test data - cursor.execute(""" - INSERT INTO #pytest_lowercase_test (ID, UserName, EMAIL_ADDRESS, PhoneNumber) - VALUES (1, 'JohnDoe', 'john@example.com', '555-1234') - """) - db_connection.commit() - - # First test with lowercase=False (default) - mssql_python.lowercase = False - cursor1 = db_connection.cursor() - cursor1.execute("SELECT * FROM #pytest_lowercase_test") - - # Description column names should preserve original case - column_names1 = [desc[0] for desc in cursor1.description] - assert "ID" in column_names1, "Column 'ID' should be present with original case" - assert "UserName" in column_names1, "Column 'UserName' should be present with original case" - - # Make sure to consume all results and close the cursor - cursor1.fetchall() - cursor1.close() - - # Now test with lowercase=True - mssql_python.lowercase = True - cursor2 = db_connection.cursor() - cursor2.execute("SELECT * FROM #pytest_lowercase_test") - - # Description column names should be lowercase - column_names2 = [desc[0] for desc in cursor2.description] - assert "id" in column_names2, "Column names should be lowercase when lowercase=True" - assert "username" in column_names2, "Column names should be lowercase when lowercase=True" - - # Make sure to consume all results and close the cursor - cursor2.fetchall() - cursor2.close() - - # Create a fresh cursor for cleanup - drop_cursor = db_connection.cursor() - - finally: - # Restore original setting - mssql_python.lowercase = original_lowercase - # Clean up the table - if drop_cursor: - try: - drop_cursor.execute("DROP TABLE #pytest_lowercase_test") - db_connection.commit() - drop_cursor.close() - except Exception: - pass # Suppress errors during cleanup - def test_lowercase_setting_after_cursor_creation(cursor, db_connection): """Test that changing lowercase setting after cursor creation doesn't affect existing cursor""" original_lowercase = mssql_python.lowercase @@ -7010,83 +6941,6 @@ def test_money_smallmoney_invalid_values(cursor, db_connection): drop_table_if_exists(cursor, "dbo.money_test") db_connection.commit() -def test_decimal_separator_function(cursor, db_connection): - """Test decimal separator functionality with database operations""" - # Store original value to restore after test - original_separator = mssql_python.getDecimalSeparator() - - try: - # Create test table - cursor.execute(""" - CREATE TABLE #pytest_decimal_separator_test ( - id INT PRIMARY KEY, - decimal_value DECIMAL(10, 2) - ) - """) - db_connection.commit() - - # Insert test values with default separator (.) - test_value = decimal.Decimal('123.45') - cursor.execute(""" - INSERT INTO #pytest_decimal_separator_test (id, decimal_value) - VALUES (1, ?) - """, [test_value]) - db_connection.commit() - - # First test with default decimal separator (.) - cursor.execute("SELECT id, decimal_value FROM #pytest_decimal_separator_test") - row = cursor.fetchone() - default_str = str(row) - assert '123.45' in default_str, "Default separator not found in string representation" - - # Now change to comma separator and test string representation - mssql_python.setDecimalSeparator(',') - cursor.execute("SELECT id, decimal_value FROM #pytest_decimal_separator_test") - row = cursor.fetchone() - - # This should format the decimal with a comma in the string representation - comma_str = str(row) - assert '123,45' in comma_str, f"Expected comma in string representation but got: {comma_str}" - - finally: - # Restore original decimal separator - mssql_python.setDecimalSeparator(original_separator) - - # Cleanup - cursor.execute("DROP TABLE IF EXISTS #pytest_decimal_separator_test") - db_connection.commit() - -def test_decimal_separator_basic_functionality(): - """Test basic decimal separator functionality without database operations""" - # Store original value to restore after test - original_separator = mssql_python.getDecimalSeparator() - - try: - # Test default value - assert mssql_python.getDecimalSeparator() == '.', "Default decimal separator should be '.'" - - # Test setting to comma - mssql_python.setDecimalSeparator(',') - assert mssql_python.getDecimalSeparator() == ',', "Decimal separator should be ',' after setting" - - # Test setting to other valid separators - mssql_python.setDecimalSeparator(':') - assert mssql_python.getDecimalSeparator() == ':', "Decimal separator should be ':' after setting" - - # Test invalid inputs - with pytest.raises(ValueError): - mssql_python.setDecimalSeparator('') # Empty string - - with pytest.raises(ValueError): - mssql_python.setDecimalSeparator('too_long') # More than one character - - with pytest.raises(ValueError): - mssql_python.setDecimalSeparator(123) # Not a string - - finally: - # Restore original separator - mssql_python.setDecimalSeparator(original_separator) - def test_decimal_separator_with_multiple_values(cursor, db_connection): """Test decimal separator with multiple different decimal values""" original_separator = mssql_python.getDecimalSeparator() @@ -7182,6 +7036,83 @@ def test_decimal_separator_calculations(cursor, db_connection): cursor.execute("DROP TABLE IF EXISTS #pytest_decimal_calc_test") db_connection.commit() +def test_decimal_separator_function(cursor, db_connection): + """Test decimal separator functionality with database operations""" + # Store original value to restore after test + original_separator = mssql_python.getDecimalSeparator() + + try: + # Create test table + cursor.execute(""" + CREATE TABLE #pytest_decimal_separator_test ( + id INT PRIMARY KEY, + decimal_value DECIMAL(10, 2) + ) + """) + db_connection.commit() + + # Insert test values with default separator (.) + test_value = decimal.Decimal('123.45') + cursor.execute(""" + INSERT INTO #pytest_decimal_separator_test (id, decimal_value) + VALUES (1, ?) + """, [test_value]) + db_connection.commit() + + # First test with default decimal separator (.) + cursor.execute("SELECT id, decimal_value FROM #pytest_decimal_separator_test") + row = cursor.fetchone() + default_str = str(row) + assert '123.45' in default_str, "Default separator not found in string representation" + + # Now change to comma separator and test string representation + mssql_python.setDecimalSeparator(',') + cursor.execute("SELECT id, decimal_value FROM #pytest_decimal_separator_test") + row = cursor.fetchone() + + # This should format the decimal with a comma in the string representation + comma_str = str(row) + assert '123,45' in comma_str, f"Expected comma in string representation but got: {comma_str}" + + finally: + # Restore original decimal separator + mssql_python.setDecimalSeparator(original_separator) + + # Cleanup + cursor.execute("DROP TABLE IF EXISTS #pytest_decimal_separator_test") + db_connection.commit() + +def test_decimal_separator_basic_functionality(): + """Test basic decimal separator functionality without database operations""" + # Store original value to restore after test + original_separator = mssql_python.getDecimalSeparator() + + try: + # Test default value + assert mssql_python.getDecimalSeparator() == '.', "Default decimal separator should be '.'" + + # Test setting to comma + mssql_python.setDecimalSeparator(',') + assert mssql_python.getDecimalSeparator() == ',', "Decimal separator should be ',' after setting" + + # Test setting to other valid separators + mssql_python.setDecimalSeparator(':') + assert mssql_python.getDecimalSeparator() == ':', "Decimal separator should be ':' after setting" + + # Test invalid inputs + with pytest.raises(ValueError): + mssql_python.setDecimalSeparator('') # Empty string + + with pytest.raises(ValueError): + mssql_python.setDecimalSeparator('too_long') # More than one character + + with pytest.raises(ValueError): + mssql_python.setDecimalSeparator(123) # Not a string + + finally: + # Restore original separator + mssql_python.setDecimalSeparator(original_separator) + def test_close(db_connection): """Test closing the cursor""" try: