From c55f7a439e10235c1cd1d9cbaf8009550e0fd1ea Mon Sep 17 00:00:00 2001 From: Jahnvi Thakkar Date: Thu, 7 Aug 2025 12:38:07 +0530 Subject: [PATCH 1/5] FEAT: Adding conn.setencoding() API --- mssql_python/connection.py | 96 ++++++++++- mssql_python/type.py | 2 +- tests/test_003_connection.py | 319 +++++++++++++++++++++++++++++++++++ 3 files changed, 415 insertions(+), 2 deletions(-) diff --git a/mssql_python/connection.py b/mssql_python/connection.py index 12760df41..b68fd75e1 100644 --- a/mssql_python/connection.py +++ b/mssql_python/connection.py @@ -12,12 +12,14 @@ """ import weakref import re +import codecs from mssql_python.cursor import Cursor from mssql_python.helpers import add_driver_to_connection_str, sanitize_connection_string, log from mssql_python import ddbc_bindings from mssql_python.pooling import PoolingManager -from mssql_python.exceptions import InterfaceError +from mssql_python.exceptions import InterfaceError, ProgrammingError from mssql_python.auth import process_connection_string +from mssql_python.constants import ConstantsDDBC class Connection: @@ -36,6 +38,7 @@ class Connection: commit() -> None: rollback() -> None: close() -> None: + setencoding(encoding=None, ctype=None) -> None: """ def __init__(self, connection_str: str = "", autocommit: bool = False, attrs_before: dict = None, **kwargs) -> None: @@ -63,6 +66,13 @@ def __init__(self, connection_str: str = "", autocommit: bool = False, attrs_bef ) self._attrs_before = attrs_before or {} + # Initialize encoding settings with defaults for Python 3 + # Python 3 only has str (which is Unicode), so we use utf-16le by default + self._encoding_settings = { + 'encoding': 'utf-16le', + 'ctype': ConstantsDDBC.SQL_WCHAR.value + } + # Check if the connection string contains authentication parameters # This is important for processing the connection string correctly. # If authentication is specified, it will be processed to handle @@ -159,6 +169,90 @@ def setautocommit(self, value: bool = False) -> None: """ self._conn.set_autocommit(value) + def setencoding(self, encoding=None, ctype=None): + """ + Sets the text encoding for SQL statements and text parameters. + + Since Python 3 only has str (which is Unicode), this method configures + how text is encoded when sending to the database. + + Args: + encoding (str, optional): The encoding to use. This must be a valid Python + encoding that converts text to bytes. If None, defaults to 'utf-16le'. + ctype (int, optional): The C data type to use when passing data: + SQL_CHAR or SQL_WCHAR. If not provided, SQL_WCHAR is used for + "utf-16", "utf-16le", and "utf-16be". SQL_CHAR is used for all other encodings. + + Returns: + None + + Raises: + ProgrammingError: If the encoding is not valid or not supported. + InterfaceError: If the connection is closed. + + Example: + # For databases that only communicate with UTF-8 + cnxn.setencoding(encoding='utf-8') + + # For explicitly using SQL_CHAR + cnxn.setencoding(encoding='utf-8', ctype=mssql_python.SQL_CHAR) + """ + if self._closed: + raise InterfaceError( + driver_error="Cannot set encoding on closed connection", + ddbc_error="Cannot set encoding on closed connection", + ) + + # Set default encoding if not provided + if encoding is None: + encoding = 'utf-16le' + + # Validate encoding + try: + codecs.lookup(encoding) + except LookupError: + raise ProgrammingError( + driver_error=f"Unknown encoding: {encoding}", + ddbc_error=f"The encoding '{encoding}' is not supported by Python", + ) + + # Set default ctype based on encoding if not provided + if ctype is None: + if encoding.lower() in ('utf-16', 'utf-16le', 'utf-16be'): + ctype = ConstantsDDBC.SQL_WCHAR.value + else: + ctype = ConstantsDDBC.SQL_CHAR.value + + # Validate ctype + valid_ctypes = [ConstantsDDBC.SQL_CHAR.value, ConstantsDDBC.SQL_WCHAR.value] + if ctype not in valid_ctypes: + raise ProgrammingError( + driver_error=f"Invalid ctype: {ctype}", + ddbc_error=f"ctype must be SQL_CHAR ({ConstantsDDBC.SQL_CHAR.value}) or SQL_WCHAR ({ConstantsDDBC.SQL_WCHAR.value})", + ) + + # Store the encoding settings + self._encoding_settings = { + 'encoding': encoding, + 'ctype': ctype + } + + log('info', "Text encoding set to %s with ctype %s", encoding, ctype) + + def getencoding(self): + """ + Gets the current text encoding settings. + + Returns: + dict: A dictionary containing 'encoding' and 'ctype' keys. + + Example: + settings = cnxn.getencoding() + print(f"Current encoding: {settings['encoding']}") + print(f"Current ctype: {settings['ctype']}") + """ + return self._encoding_settings.copy() + def cursor(self) -> Cursor: """ Return a new Cursor object using the connection. diff --git a/mssql_python/type.py b/mssql_python/type.py index 0c9cfde61..69ecf2514 100644 --- a/mssql_python/type.py +++ b/mssql_python/type.py @@ -104,7 +104,7 @@ def Binary(value) -> bytes: """ Converts a string or bytes to bytes for use with binary database columns. - This function follows the DB-API 2.0 specification and pyodbc compatibility. + This function follows the DB-API 2.0 specification. It accepts only str and bytes/bytearray types to ensure type safety. Args: diff --git a/tests/test_003_connection.py b/tests/test_003_connection.py index c71e769b9..4fb6d3e9b 100644 --- a/tests/test_003_connection.py +++ b/tests/test_003_connection.py @@ -482,3 +482,322 @@ def test_connection_pooling_basic(conn_str): conn1.close() conn2.close() + +def test_setencoding_default_settings(db_connection): + """Test that default encoding settings are correct.""" + settings = db_connection.getencoding() + assert settings['encoding'] == 'utf-16le', "Default encoding should be utf-16le" + assert settings['ctype'] == -8, "Default ctype should be SQL_WCHAR (-8)" + +def test_setencoding_basic_functionality(db_connection): + """Test basic setencoding functionality.""" + # Test setting UTF-8 encoding + db_connection.setencoding(encoding='utf-8') + settings = db_connection.getencoding() + assert settings['encoding'] == 'utf-8', "Encoding should be set to utf-8" + assert settings['ctype'] == 1, "ctype should default to SQL_CHAR (1) for utf-8" + + # Test setting UTF-16LE with explicit ctype + db_connection.setencoding(encoding='utf-16le', ctype=-8) + settings = db_connection.getencoding() + assert settings['encoding'] == 'utf-16le', "Encoding should be set to utf-16le" + assert settings['ctype'] == -8, "ctype should be SQL_WCHAR (-8)" + +def test_setencoding_automatic_ctype_detection(db_connection): + """Test automatic ctype detection based on encoding.""" + # UTF-16 variants should default to SQL_WCHAR + utf16_encodings = ['utf-16', 'utf-16le', 'utf-16be'] + for encoding in utf16_encodings: + db_connection.setencoding(encoding=encoding) + settings = db_connection.getencoding() + assert settings['ctype'] == -8, f"{encoding} should default to SQL_WCHAR (-8)" + + # Other encodings should default to SQL_CHAR + other_encodings = ['utf-8', 'latin-1', 'ascii'] + for encoding in other_encodings: + db_connection.setencoding(encoding=encoding) + settings = db_connection.getencoding() + assert settings['ctype'] == 1, f"{encoding} should default to SQL_CHAR (1)" + +def test_setencoding_explicit_ctype_override(db_connection): + """Test that explicit ctype parameter overrides automatic detection.""" + # Set UTF-8 with SQL_WCHAR (override default) + db_connection.setencoding(encoding='utf-8', ctype=-8) + settings = db_connection.getencoding() + assert settings['encoding'] == 'utf-8', "Encoding should be utf-8" + assert settings['ctype'] == -8, "ctype should be SQL_WCHAR (-8) when explicitly set" + + # Set UTF-16LE with SQL_CHAR (override default) + db_connection.setencoding(encoding='utf-16le', ctype=1) + settings = db_connection.getencoding() + assert settings['encoding'] == 'utf-16le', "Encoding should be utf-16le" + assert settings['ctype'] == 1, "ctype should be SQL_CHAR (1) when explicitly set" + +def test_setencoding_none_parameters(db_connection): + """Test setencoding with None parameters.""" + # Test with encoding=None (should use default) + db_connection.setencoding(encoding=None) + settings = db_connection.getencoding() + assert settings['encoding'] == 'utf-16le', "encoding=None should use default utf-16le" + assert settings['ctype'] == -8, "ctype should be SQL_WCHAR for utf-16le" + + # Test with both None (should use defaults) + db_connection.setencoding(encoding=None, ctype=None) + settings = db_connection.getencoding() + assert settings['encoding'] == 'utf-16le', "encoding=None should use default utf-16le" + assert settings['ctype'] == -8, "ctype=None should use default SQL_WCHAR" + +def test_setencoding_invalid_encoding(db_connection): + """Test setencoding with invalid encoding.""" + from mssql_python.exceptions import ProgrammingError + + with pytest.raises(ProgrammingError) as exc_info: + db_connection.setencoding(encoding='invalid-encoding-name') + + assert "Unknown encoding" in str(exc_info.value), "Should raise ProgrammingError for invalid encoding" + assert "invalid-encoding-name" in str(exc_info.value), "Error message should include the invalid encoding name" + +def test_setencoding_invalid_ctype(db_connection): + """Test setencoding with invalid ctype.""" + from mssql_python.exceptions import ProgrammingError + + with pytest.raises(ProgrammingError) as exc_info: + db_connection.setencoding(encoding='utf-8', ctype=999) + + assert "Invalid ctype" in str(exc_info.value), "Should raise ProgrammingError for invalid ctype" + assert "999" in str(exc_info.value), "Error message should include the invalid ctype value" + +def test_setencoding_closed_connection(conn_str): + """Test setencoding on closed connection.""" + from mssql_python.exceptions import InterfaceError + + temp_conn = connect(conn_str) + temp_conn.close() + + with pytest.raises(InterfaceError) as exc_info: + temp_conn.setencoding(encoding='utf-8') + + assert "closed connection" in str(exc_info.value).lower(), "Should raise InterfaceError for closed connection" + +def test_setencoding_constants_access(): + """Test that SQL_CHAR and SQL_WCHAR constants are accessible.""" + import mssql_python + + # Test constants exist and have correct values + assert hasattr(mssql_python, 'SQL_CHAR'), "SQL_CHAR constant should be available" + assert hasattr(mssql_python, 'SQL_WCHAR'), "SQL_WCHAR constant should be available" + assert mssql_python.SQL_CHAR == 1, "SQL_CHAR should have value 1" + assert mssql_python.SQL_WCHAR == -8, "SQL_WCHAR should have value -8" + +def test_setencoding_with_constants(db_connection): + """Test setencoding using module constants.""" + import mssql_python + + # Test with SQL_CHAR constant + db_connection.setencoding(encoding='utf-8', ctype=mssql_python.SQL_CHAR) + settings = db_connection.getencoding() + assert settings['ctype'] == mssql_python.SQL_CHAR, "Should accept SQL_CHAR constant" + + # Test with SQL_WCHAR constant + db_connection.setencoding(encoding='utf-16le', ctype=mssql_python.SQL_WCHAR) + settings = db_connection.getencoding() + assert settings['ctype'] == mssql_python.SQL_WCHAR, "Should accept SQL_WCHAR constant" + +def test_setencoding_common_encodings(db_connection): + """Test setencoding with various common encodings.""" + common_encodings = [ + 'utf-8', + 'utf-16le', + 'utf-16be', + 'utf-16', + 'latin-1', + 'ascii', + 'cp1252' + ] + + for encoding in common_encodings: + try: + db_connection.setencoding(encoding=encoding) + settings = db_connection.getencoding() + assert settings['encoding'] == encoding, f"Failed to set encoding {encoding}" + except Exception as e: + pytest.fail(f"Failed to set valid encoding {encoding}: {e}") + +def test_setencoding_persistence_across_cursors(db_connection): + """Test that encoding settings persist across cursor operations.""" + # Set custom encoding + db_connection.setencoding(encoding='utf-8', ctype=1) + + # Create cursors and verify encoding persists + cursor1 = db_connection.cursor() + settings1 = db_connection.getencoding() + + cursor2 = db_connection.cursor() + settings2 = db_connection.getencoding() + + assert settings1 == settings2, "Encoding settings should persist across cursor creation" + assert settings1['encoding'] == 'utf-8', "Encoding should remain utf-8" + assert settings1['ctype'] == 1, "ctype should remain SQL_CHAR" + + cursor1.close() + cursor2.close() + +@pytest.mark.skip("Skipping Unicode data tests till we have support for Unicode") +def test_setencoding_with_unicode_data(db_connection): + """Test setencoding with actual Unicode data operations.""" + # Test UTF-8 encoding with Unicode data + db_connection.setencoding(encoding='utf-8') + cursor = db_connection.cursor() + + try: + # Create test table + cursor.execute("CREATE TABLE #test_encoding_unicode (text_col NVARCHAR(100))") + + # Test various Unicode strings + test_strings = [ + "Hello, World!", + "Hello, 世界!", # Chinese + "Привет, мир!", # Russian + "مرحبا بالعالم", # Arabic + "🌍🌎🌏", # Emoji + ] + + for test_string in test_strings: + # Insert data + cursor.execute("INSERT INTO #test_encoding_unicode (text_col) VALUES (?)", test_string) + + # Retrieve and verify + cursor.execute("SELECT text_col FROM #test_encoding_unicode WHERE text_col = ?", test_string) + result = cursor.fetchone() + + assert result is not None, f"Failed to retrieve Unicode string: {test_string}" + assert result[0] == test_string, f"Unicode string mismatch: expected {test_string}, got {result[0]}" + + # Clear for next test + cursor.execute("DELETE FROM #test_encoding_unicode") + + except Exception as e: + pytest.fail(f"Unicode data test failed with UTF-8 encoding: {e}") + finally: + try: + cursor.execute("DROP TABLE #test_encoding_unicode") + except: + pass + cursor.close() + +def test_setencoding_before_and_after_operations(db_connection): + """Test that setencoding works both before and after database operations.""" + cursor = db_connection.cursor() + + try: + # Initial encoding setting + db_connection.setencoding(encoding='utf-16le') + + # Perform database operation + cursor.execute("SELECT 'Initial test' as message") + result1 = cursor.fetchone() + assert result1[0] == 'Initial test', "Initial operation failed" + + # Change encoding after operation + db_connection.setencoding(encoding='utf-8') + settings = db_connection.getencoding() + assert settings['encoding'] == 'utf-8', "Failed to change encoding after operation" + + # Perform another operation with new encoding + cursor.execute("SELECT 'Changed encoding test' as message") + result2 = cursor.fetchone() + assert result2[0] == 'Changed encoding test', "Operation after encoding change failed" + + except Exception as e: + pytest.fail(f"Encoding change test failed: {e}") + finally: + cursor.close() + +def test_getencoding_returns_copy(db_connection): + """Test that getencoding returns a copy, not reference to internal data.""" + original_settings = db_connection.getencoding() + + # Modify the returned dictionary + original_settings['encoding'] = 'modified' + original_settings['ctype'] = 999 + + # Verify internal settings weren't affected + current_settings = db_connection.getencoding() + assert current_settings['encoding'] != 'modified', "getencoding should return a copy" + assert current_settings['ctype'] != 999, "getencoding should return a copy" + +def test_setencoding_thread_safety(conn_str): + """Test setencoding behavior with multiple connections (thread safety indication).""" + import threading + + def worker(connection_str, encoding, results, index): + try: + conn = connect(connection_str) + conn.setencoding(encoding=encoding) + settings = conn.getencoding() + results[index] = settings['encoding'] + conn.close() + except Exception as e: + results[index] = f"Error: {e}" + + # Test with multiple threads setting different encodings + results = [None] * 3 + threads = [] + encodings = ['utf-8', 'utf-16le', 'latin-1'] + + for i, encoding in enumerate(encodings): + thread = threading.Thread(target=worker, args=(conn_str, encoding, results, i)) + threads.append(thread) + thread.start() + + for thread in threads: + thread.join() + + # Verify each connection got its own encoding setting + for i, expected_encoding in enumerate(encodings): + assert results[i] == expected_encoding, f"Thread {i} failed to set encoding {expected_encoding}: {results[i]}" + +def test_setencoding_parameter_validation_edge_cases(db_connection): + """Test edge cases for parameter validation.""" + from mssql_python.exceptions import ProgrammingError + + # Test empty string encoding + with pytest.raises(ProgrammingError): + db_connection.setencoding(encoding='') + + # Test non-string encoding (should be handled gracefully or raise appropriate error) + with pytest.raises((ProgrammingError, TypeError)): + db_connection.setencoding(encoding=123) + + # Test non-integer ctype + with pytest.raises((ProgrammingError, TypeError)): + db_connection.setencoding(encoding='utf-8', ctype='invalid') + +def test_setencoding_case_sensitivity(db_connection): + """Test encoding name case sensitivity.""" + # Most Python codecs are case-insensitive, but test common variations + case_variations = [ + ('utf-8', 'UTF-8'), + ('utf-16le', 'UTF-16LE'), + ('latin-1', 'LATIN-1'), + ('ascii', 'ASCII') + ] + + for lower, upper in case_variations: + try: + # Test lowercase + db_connection.setencoding(encoding=lower) + settings_lower = db_connection.getencoding() + + # Test uppercase + db_connection.setencoding(encoding=upper) + settings_upper = db_connection.getencoding() + + # Both should work (Python codecs are generally case-insensitive) + assert settings_lower['encoding'] == lower, f"Failed to set {lower}" + assert settings_upper['encoding'] == upper, f"Failed to set {upper}" + + except Exception as e: + # If one variant fails, both should fail consistently + with pytest.raises(type(e)): + db_connection.setencoding(encoding=lower if encoding == upper else upper) \ No newline at end of file From 751b0b8cfe0fa6264da13f705cbc7e9991291185 Mon Sep 17 00:00:00 2001 From: Jahnvi Thakkar Date: Thu, 7 Aug 2025 12:43:03 +0530 Subject: [PATCH 2/5] Adding init.py --- mssql_python/__init__.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mssql_python/__init__.py b/mssql_python/__init__.py index 6bf957779..8e118cd5b 100644 --- a/mssql_python/__init__.py +++ b/mssql_python/__init__.py @@ -47,6 +47,10 @@ # Constants from .constants import ConstantsDDBC +# Export specific constants for setencoding() +SQL_CHAR = ConstantsDDBC.SQL_CHAR.value +SQL_WCHAR = ConstantsDDBC.SQL_WCHAR.value + # GLOBALS # Read-Only apilevel = "2.0" @@ -71,4 +75,3 @@ def pooling(max_size=100, idle_timeout=600, enabled=True): PoolingManager.disable() else: PoolingManager.enable(max_size, idle_timeout) - \ No newline at end of file From 600c1135ffbcb5cdd4d3da1fd20a57c8bd7cdf35 Mon Sep 17 00:00:00 2001 From: Jahnvi Thakkar Date: Thu, 7 Aug 2025 14:46:51 +0530 Subject: [PATCH 3/5] Resolving comments --- mssql_python/connection.py | 73 +++++++-- mssql_python/helpers.py | 28 ++++ tests/test_003_connection.py | 284 +++++++++++++++++++++++++---------- 3 files changed, 292 insertions(+), 93 deletions(-) diff --git a/mssql_python/connection.py b/mssql_python/connection.py index b68fd75e1..f2b6d5ba5 100644 --- a/mssql_python/connection.py +++ b/mssql_python/connection.py @@ -13,14 +13,49 @@ import weakref import re import codecs +from functools import lru_cache from mssql_python.cursor import Cursor -from mssql_python.helpers import add_driver_to_connection_str, sanitize_connection_string, log +from mssql_python.helpers import add_driver_to_connection_str, sanitize_connection_string, sanitize_user_input, log from mssql_python import ddbc_bindings from mssql_python.pooling import PoolingManager from mssql_python.exceptions import InterfaceError, ProgrammingError from mssql_python.auth import process_connection_string from mssql_python.constants import ConstantsDDBC +# UTF-16 encoding variants that should use SQL_WCHAR by default +UTF16_ENCODINGS = frozenset([ + 'utf-16', + 'utf-16le', + 'utf-16be' +]) + +# Cache for encoding validation to improve performance +# Using a simple dict instead of lru_cache for module-level caching +_ENCODING_VALIDATION_CACHE = {} +_CACHE_MAX_SIZE = 100 # Limit cache size to prevent memory bloat + + +@lru_cache(maxsize=128) +def _validate_encoding(encoding: str) -> bool: + """ + Cached encoding validation using codecs.lookup(). + + Args: + encoding (str): The encoding name to validate. + + Returns: + bool: True if encoding is valid, False otherwise. + + Note: + Uses LRU cache to avoid repeated expensive codecs.lookup() calls. + Cache size is limited to 128 entries which should cover most use cases. + """ + try: + codecs.lookup(encoding) + return True + except LookupError: + return False + class Connection: """ @@ -181,7 +216,7 @@ def setencoding(self, encoding=None, ctype=None): encoding that converts text to bytes. If None, defaults to 'utf-16le'. ctype (int, optional): The C data type to use when passing data: SQL_CHAR or SQL_WCHAR. If not provided, SQL_WCHAR is used for - "utf-16", "utf-16le", and "utf-16be". SQL_CHAR is used for all other encodings. + UTF-16 variants (see UTF16_ENCODINGS constant). SQL_CHAR is used for all other encodings. Returns: None @@ -199,26 +234,29 @@ def setencoding(self, encoding=None, ctype=None): """ if self._closed: raise InterfaceError( - driver_error="Cannot set encoding on closed connection", - ddbc_error="Cannot set encoding on closed connection", + driver_error="Connection is closed", + ddbc_error="Connection is closed", ) # Set default encoding if not provided if encoding is None: encoding = 'utf-16le' - # Validate encoding - try: - codecs.lookup(encoding) - except LookupError: + # Validate encoding using cached validation for better performance + if not _validate_encoding(encoding): + # Log the sanitized encoding for security + log('warning', "Invalid encoding attempted: %s", sanitize_user_input(str(encoding))) raise ProgrammingError( - driver_error=f"Unknown encoding: {encoding}", + driver_error=f"Unsupported encoding: {encoding}", ddbc_error=f"The encoding '{encoding}' is not supported by Python", ) + # Normalize encoding to lowercase for consistency + encoding = encoding.lower() + # Set default ctype based on encoding if not provided if ctype is None: - if encoding.lower() in ('utf-16', 'utf-16le', 'utf-16be'): + if encoding in UTF16_ENCODINGS: ctype = ConstantsDDBC.SQL_WCHAR.value else: ctype = ConstantsDDBC.SQL_CHAR.value @@ -226,6 +264,8 @@ def setencoding(self, encoding=None, ctype=None): # Validate ctype valid_ctypes = [ConstantsDDBC.SQL_CHAR.value, ConstantsDDBC.SQL_WCHAR.value] if ctype not in valid_ctypes: + # Log the sanitized ctype for security + log('warning', "Invalid ctype attempted: %s", sanitize_user_input(str(ctype))) raise ProgrammingError( driver_error=f"Invalid ctype: {ctype}", ddbc_error=f"ctype must be SQL_CHAR ({ConstantsDDBC.SQL_CHAR.value}) or SQL_WCHAR ({ConstantsDDBC.SQL_WCHAR.value})", @@ -237,7 +277,9 @@ def setencoding(self, encoding=None, ctype=None): 'ctype': ctype } - log('info', "Text encoding set to %s with ctype %s", encoding, ctype) + # Log with sanitized values for security + log('info', "Text encoding set to %s with ctype %s", + sanitize_user_input(encoding), sanitize_user_input(str(ctype))) def getencoding(self): """ @@ -246,11 +288,20 @@ def getencoding(self): Returns: dict: A dictionary containing 'encoding' and 'ctype' keys. + Raises: + InterfaceError: If the connection is closed. + Example: settings = cnxn.getencoding() print(f"Current encoding: {settings['encoding']}") print(f"Current ctype: {settings['ctype']}") """ + if self._closed: + raise InterfaceError( + driver_error="Connection is closed", + ddbc_error="Connection is closed", + ) + return self._encoding_settings.copy() def cursor(self) -> Cursor: diff --git a/mssql_python/helpers.py b/mssql_python/helpers.py index f15365c99..2ac3c6694 100644 --- a/mssql_python/helpers.py +++ b/mssql_python/helpers.py @@ -128,6 +128,34 @@ def sanitize_connection_string(conn_str: str) -> str: return re.sub(r"(Pwd\s*=\s*)[^;]*", r"\1***", conn_str, flags=re.IGNORECASE) +def sanitize_user_input(user_input: str, max_length: int = 50) -> str: + """ + Sanitize user input for safe logging by removing control characters, + limiting length, and ensuring safe characters only. + + Args: + user_input (str): The user input to sanitize. + max_length (int): Maximum length of the sanitized output. + + Returns: + str: The sanitized string safe for logging. + """ + if not isinstance(user_input, str): + return "" + + # Remove control characters and non-printable characters + import re + # Allow alphanumeric, dash, underscore, and dot (common in encoding names) + sanitized = re.sub(r'[^\w\-\.]', '', user_input) + + # Limit length to prevent log flooding + if len(sanitized) > max_length: + sanitized = sanitized[:max_length] + "..." + + # Return placeholder if nothing remains after sanitization + return sanitized if sanitized else "" + + def log(level: str, message: str, *args) -> None: """ Universal logging helper that gets a fresh logger instance. diff --git a/tests/test_003_connection.py b/tests/test_003_connection.py index 4fb6d3e9b..30b08e62d 100644 --- a/tests/test_003_connection.py +++ b/tests/test_003_connection.py @@ -21,7 +21,7 @@ from mssql_python.exceptions import InterfaceError import pytest import time -from mssql_python import Connection, connect, pooling +from mssql_python import connect, Connection, pooling, SQL_CHAR, SQL_WCHAR import threading def drop_table_if_exists(cursor, table_name): @@ -713,91 +713,211 @@ def test_setencoding_before_and_after_operations(db_connection): finally: cursor.close() -def test_getencoding_returns_copy(db_connection): - """Test that getencoding returns a copy, not reference to internal data.""" - original_settings = db_connection.getencoding() - - # Modify the returned dictionary - original_settings['encoding'] = 'modified' - original_settings['ctype'] = 999 - - # Verify internal settings weren't affected - current_settings = db_connection.getencoding() - assert current_settings['encoding'] != 'modified', "getencoding should return a copy" - assert current_settings['ctype'] != 999, "getencoding should return a copy" +def test_getencoding_default(conn_str): + """Test getencoding returns default settings""" + conn = connect(conn_str) + try: + encoding_info = conn.getencoding() + assert isinstance(encoding_info, dict) + assert 'encoding' in encoding_info + assert 'ctype' in encoding_info + # Default should be utf-16le with SQL_WCHAR + assert encoding_info['encoding'] == 'utf-16le' + assert encoding_info['ctype'] == SQL_WCHAR + finally: + conn.close() -def test_setencoding_thread_safety(conn_str): - """Test setencoding behavior with multiple connections (thread safety indication).""" - import threading - - def worker(connection_str, encoding, results, index): - try: - conn = connect(connection_str) - conn.setencoding(encoding=encoding) - settings = conn.getencoding() - results[index] = settings['encoding'] - conn.close() - except Exception as e: - results[index] = f"Error: {e}" - - # Test with multiple threads setting different encodings - results = [None] * 3 - threads = [] - encodings = ['utf-8', 'utf-16le', 'latin-1'] - - for i, encoding in enumerate(encodings): - thread = threading.Thread(target=worker, args=(conn_str, encoding, results, i)) - threads.append(thread) - thread.start() +def test_getencoding_returns_copy(conn_str): + """Test getencoding returns a copy (not reference)""" + conn = connect(conn_str) + try: + encoding_info1 = conn.getencoding() + encoding_info2 = conn.getencoding() + + # Should be equal but not the same object + assert encoding_info1 == encoding_info2 + assert encoding_info1 is not encoding_info2 + + # Modifying one shouldn't affect the other + encoding_info1['encoding'] = 'modified' + assert encoding_info2['encoding'] != 'modified' + finally: + conn.close() + +def test_getencoding_closed_connection(conn_str): + """Test getencoding on closed connection raises InterfaceError""" + conn = connect(conn_str) + conn.close() - for thread in threads: - thread.join() + with pytest.raises(InterfaceError, match="Connection is closed"): + conn.getencoding() + +def test_setencoding_getencoding_consistency(conn_str): + """Test that setencoding and getencoding work consistently together""" + conn = connect(conn_str) + try: + test_cases = [ + ('utf-8', SQL_CHAR), + ('utf-16le', SQL_WCHAR), + ('latin-1', SQL_CHAR), + ('ascii', SQL_CHAR), + ] + + for encoding, expected_ctype in test_cases: + conn.setencoding(encoding) + encoding_info = conn.getencoding() + assert encoding_info['encoding'] == encoding.lower() + assert encoding_info['ctype'] == expected_ctype + finally: + conn.close() + +def test_setencoding_default_encoding(conn_str): + """Test setencoding with default UTF-16LE encoding""" + conn = connect(conn_str) + try: + conn.setencoding() + encoding_info = conn.getencoding() + assert encoding_info['encoding'] == 'utf-16le' + assert encoding_info['ctype'] == SQL_WCHAR + finally: + conn.close() + +def test_setencoding_utf8(conn_str): + """Test setencoding with UTF-8 encoding""" + conn = connect(conn_str) + try: + conn.setencoding('utf-8') + encoding_info = conn.getencoding() + assert encoding_info['encoding'] == 'utf-8' + assert encoding_info['ctype'] == SQL_CHAR + finally: + conn.close() + +def test_setencoding_latin1(conn_str): + """Test setencoding with latin-1 encoding""" + conn = connect(conn_str) + try: + conn.setencoding('latin-1') + encoding_info = conn.getencoding() + assert encoding_info['encoding'] == 'latin-1' + assert encoding_info['ctype'] == SQL_CHAR + finally: + conn.close() + +def test_setencoding_with_explicit_ctype_sql_char(conn_str): + """Test setencoding with explicit SQL_CHAR ctype""" + conn = connect(conn_str) + try: + conn.setencoding('utf-8', SQL_CHAR) + encoding_info = conn.getencoding() + assert encoding_info['encoding'] == 'utf-8' + assert encoding_info['ctype'] == SQL_CHAR + finally: + conn.close() + +def test_setencoding_with_explicit_ctype_sql_wchar(conn_str): + """Test setencoding with explicit SQL_WCHAR ctype""" + conn = connect(conn_str) + try: + conn.setencoding('utf-16le', SQL_WCHAR) + encoding_info = conn.getencoding() + assert encoding_info['encoding'] == 'utf-16le' + assert encoding_info['ctype'] == SQL_WCHAR + finally: + conn.close() + +def test_setencoding_invalid_encoding(conn_str): + """Test setencoding with invalid encoding raises ProgrammingError""" + from mssql_python.exceptions import ProgrammingError - # Verify each connection got its own encoding setting - for i, expected_encoding in enumerate(encodings): - assert results[i] == expected_encoding, f"Thread {i} failed to set encoding {expected_encoding}: {results[i]}" + conn = connect(conn_str) + try: + with pytest.raises(ProgrammingError, match="Unsupported encoding"): + conn.setencoding('invalid-encoding-name') + finally: + conn.close() -def test_setencoding_parameter_validation_edge_cases(db_connection): - """Test edge cases for parameter validation.""" +def test_setencoding_invalid_ctype(conn_str): + """Test setencoding with invalid ctype raises ProgrammingError""" from mssql_python.exceptions import ProgrammingError - # Test empty string encoding - with pytest.raises(ProgrammingError): - db_connection.setencoding(encoding='') - - # Test non-string encoding (should be handled gracefully or raise appropriate error) - with pytest.raises((ProgrammingError, TypeError)): - db_connection.setencoding(encoding=123) - - # Test non-integer ctype - with pytest.raises((ProgrammingError, TypeError)): - db_connection.setencoding(encoding='utf-8', ctype='invalid') - -def test_setencoding_case_sensitivity(db_connection): - """Test encoding name case sensitivity.""" - # Most Python codecs are case-insensitive, but test common variations - case_variations = [ - ('utf-8', 'UTF-8'), - ('utf-16le', 'UTF-16LE'), - ('latin-1', 'LATIN-1'), - ('ascii', 'ASCII') - ] + conn = connect(conn_str) + try: + with pytest.raises(ProgrammingError, match="Invalid ctype"): + conn.setencoding('utf-8', 999) + finally: + conn.close() + +def test_setencoding_closed_connection(conn_str): + """Test setencoding on closed connection raises InterfaceError""" + conn = connect(conn_str) + conn.close() - for lower, upper in case_variations: - try: - # Test lowercase - db_connection.setencoding(encoding=lower) - settings_lower = db_connection.getencoding() - - # Test uppercase - db_connection.setencoding(encoding=upper) - settings_upper = db_connection.getencoding() - - # Both should work (Python codecs are generally case-insensitive) - assert settings_lower['encoding'] == lower, f"Failed to set {lower}" - assert settings_upper['encoding'] == upper, f"Failed to set {upper}" - - except Exception as e: - # If one variant fails, both should fail consistently - with pytest.raises(type(e)): - db_connection.setencoding(encoding=lower if encoding == upper else upper) \ No newline at end of file + with pytest.raises(InterfaceError, match="Connection is closed"): + conn.setencoding('utf-8') + +def test_setencoding_case_insensitive_encoding(conn_str): + """Test setencoding with case variations""" + conn = connect(conn_str) + try: + # Test various case formats + conn.setencoding('UTF-8') + encoding_info = conn.getencoding() + assert encoding_info['encoding'] == 'utf-8' # Should be normalized + + conn.setencoding('Utf-16LE') + encoding_info = conn.getencoding() + assert encoding_info['encoding'] == 'utf-16le' # Should be normalized + finally: + conn.close() + +def test_setencoding_none_encoding_default(conn_str): + """Test setencoding with None encoding uses default""" + conn = connect(conn_str) + try: + conn.setencoding(None) + encoding_info = conn.getencoding() + assert encoding_info['encoding'] == 'utf-16le' + assert encoding_info['ctype'] == SQL_WCHAR + finally: + conn.close() + +def test_setencoding_override_previous(conn_str): + """Test setencoding overrides previous settings""" + conn = connect(conn_str) + try: + # Set initial encoding + conn.setencoding('utf-8') + encoding_info = conn.getencoding() + assert encoding_info['encoding'] == 'utf-8' + assert encoding_info['ctype'] == SQL_CHAR + + # Override with different encoding + conn.setencoding('utf-16le') + encoding_info = conn.getencoding() + assert encoding_info['encoding'] == 'utf-16le' + assert encoding_info['ctype'] == SQL_WCHAR + finally: + conn.close() + +def test_setencoding_ascii(conn_str): + """Test setencoding with ASCII encoding""" + conn = connect(conn_str) + try: + conn.setencoding('ascii') + encoding_info = conn.getencoding() + assert encoding_info['encoding'] == 'ascii' + assert encoding_info['ctype'] == SQL_CHAR + finally: + conn.close() + +def test_setencoding_cp1252(conn_str): + """Test setencoding with Windows-1252 encoding""" + conn = connect(conn_str) + try: + conn.setencoding('cp1252') + encoding_info = conn.getencoding() + assert encoding_info['encoding'] == 'cp1252' + assert encoding_info['ctype'] == SQL_CHAR + finally: + conn.close() \ No newline at end of file From 15b468c7d9aa48fc02d18126ace7fa46a32ca8b3 Mon Sep 17 00:00:00 2001 From: Jahnvi Thakkar Date: Tue, 26 Aug 2025 16:37:51 +0530 Subject: [PATCH 4/5] Resolving comments --- mssql_python/connection.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/mssql_python/connection.py b/mssql_python/connection.py index f2b6d5ba5..9b19a603a 100644 --- a/mssql_python/connection.py +++ b/mssql_python/connection.py @@ -13,7 +13,6 @@ import weakref import re import codecs -from functools import lru_cache from mssql_python.cursor import Cursor from mssql_python.helpers import add_driver_to_connection_str, sanitize_connection_string, sanitize_user_input, log from mssql_python import ddbc_bindings @@ -29,13 +28,7 @@ 'utf-16be' ]) -# Cache for encoding validation to improve performance -# Using a simple dict instead of lru_cache for module-level caching -_ENCODING_VALIDATION_CACHE = {} -_CACHE_MAX_SIZE = 100 # Limit cache size to prevent memory bloat - -@lru_cache(maxsize=128) def _validate_encoding(encoding: str) -> bool: """ Cached encoding validation using codecs.lookup(). @@ -251,8 +244,8 @@ def setencoding(self, encoding=None, ctype=None): ddbc_error=f"The encoding '{encoding}' is not supported by Python", ) - # Normalize encoding to lowercase for consistency - encoding = encoding.lower() + # Normalize encoding to casefold for more robust Unicode handling + encoding = encoding.casefold() # Set default ctype based on encoding if not provided if ctype is None: From afcdbf31140df0c93a7220eb8067245f5c0075d6 Mon Sep 17 00:00:00 2001 From: Jahnvi Thakkar <61936179+jahnvi480@users.noreply.github.com> Date: Wed, 27 Aug 2025 17:18:55 +0530 Subject: [PATCH 5/5] FEAT: Adding setdecoding() API for connection (#173) ### Work Item / Issue Reference > [AB#34918](https://sqlclientdrivers.visualstudio.com/c6d89619-62de-46a0-8b46-70b92a84d85e/_workitems/edit/34918) ------------------------------------------------------------------- ### Summary This pull request introduces new functionality for configuring and retrieving text decoding settings in the `Connection` class of the `mssql_python` package. The main changes add support for a new special SQL type flag (`SQL_WMETADATA`) to allow explicit control over how column metadata is decoded, and provide two new methods (`setdecoding` and `getdecoding`) for managing decoding configuration per SQL type. **Enhancements to decoding configuration:** * Added a new constant, `SQL_WMETADATA`, in both `mssql_python/__init__.py` and `mssql_python/connection.py`, to allow explicit configuration of column name decoding. * Initialized a `_decoding_settings` dictionary in the `Connection` class to store decoding settings for `SQL_CHAR`, `SQL_WCHAR`, and `SQL_WMETADATA`, with sensible Python 3 defaults. * Introduced the `setdecoding` method to the `Connection` class, allowing users to configure the decoding (encoding and ctype) for each SQL type, including validation and error handling. * Added the `getdecoding` method to the `Connection` class, enabling retrieval of the current decoding settings for a specific SQL type, with validation and error handling. **Testing configuration:** * Updated the `conn_str` fixture in `tests/conftest.py` to use a hardcoded connection string, likely for local testing purposes. --------- Co-authored-by: Jahnvi Thakkar --- mssql_python/__init__.py | 1 + mssql_python/connection.py | 163 ++++++++++++- tests/test_003_connection.py | 461 ++++++++++++++++++++++++++++++++--- 3 files changed, 596 insertions(+), 29 deletions(-) diff --git a/mssql_python/__init__.py b/mssql_python/__init__.py index 8e118cd5b..071136462 100644 --- a/mssql_python/__init__.py +++ b/mssql_python/__init__.py @@ -50,6 +50,7 @@ # Export specific constants for setencoding() SQL_CHAR = ConstantsDDBC.SQL_CHAR.value SQL_WCHAR = ConstantsDDBC.SQL_WCHAR.value +SQL_WMETADATA = -99 # GLOBALS # Read-Only diff --git a/mssql_python/connection.py b/mssql_python/connection.py index 9b19a603a..187f3e33e 100644 --- a/mssql_python/connection.py +++ b/mssql_python/connection.py @@ -21,6 +21,9 @@ from mssql_python.auth import process_connection_string from mssql_python.constants import ConstantsDDBC +# Add SQL_WMETADATA constant for metadata decoding configuration +SQL_WMETADATA = -99 # Special flag for column name decoding + # UTF-16 encoding variants that should use SQL_WCHAR by default UTF16_ENCODINGS = frozenset([ 'utf-16', @@ -28,7 +31,6 @@ 'utf-16be' ]) - def _validate_encoding(encoding: str) -> bool: """ Cached encoding validation using codecs.lookup(). @@ -67,6 +69,8 @@ class Connection: rollback() -> None: close() -> None: setencoding(encoding=None, ctype=None) -> None: + setdecoding(sqltype, encoding=None, ctype=None) -> None: + getdecoding(sqltype) -> dict: """ def __init__(self, connection_str: str = "", autocommit: bool = False, attrs_before: dict = None, **kwargs) -> None: @@ -101,6 +105,22 @@ def __init__(self, connection_str: str = "", autocommit: bool = False, attrs_bef 'ctype': ConstantsDDBC.SQL_WCHAR.value } + # Initialize decoding settings with Python 3 defaults + self._decoding_settings = { + ConstantsDDBC.SQL_CHAR.value: { + 'encoding': 'utf-8', + 'ctype': ConstantsDDBC.SQL_CHAR.value + }, + ConstantsDDBC.SQL_WCHAR.value: { + 'encoding': 'utf-16le', + 'ctype': ConstantsDDBC.SQL_WCHAR.value + }, + SQL_WMETADATA: { + 'encoding': 'utf-16le', + 'ctype': ConstantsDDBC.SQL_WCHAR.value + } + } + # Check if the connection string contains authentication parameters # This is important for processing the connection string correctly. # If authentication is specified, it will be processed to handle @@ -297,6 +317,147 @@ def getencoding(self): return self._encoding_settings.copy() + def setdecoding(self, sqltype, encoding=None, ctype=None): + """ + Sets the text decoding used when reading SQL_CHAR and SQL_WCHAR from the database. + + This method configures how text data is decoded when reading from the database. + In Python 3, all text is Unicode (str), so this primarily affects the encoding + used to decode bytes from the database. + + Args: + sqltype (int): The SQL type being configured: SQL_CHAR, SQL_WCHAR, or SQL_WMETADATA. + SQL_WMETADATA is a special flag for configuring column name decoding. + encoding (str, optional): The Python encoding to use when decoding the data. + If None, uses default encoding based on sqltype. + ctype (int, optional): The C data type to request from SQLGetData: + SQL_CHAR or SQL_WCHAR. If None, uses default based on encoding. + + Returns: + None + + Raises: + ProgrammingError: If the sqltype, encoding, or ctype is invalid. + InterfaceError: If the connection is closed. + + Example: + # Configure SQL_CHAR to use UTF-8 decoding + cnxn.setdecoding(mssql_python.SQL_CHAR, encoding='utf-8') + + # Configure column metadata decoding + cnxn.setdecoding(mssql_python.SQL_WMETADATA, encoding='utf-16le') + + # Use explicit ctype + cnxn.setdecoding(mssql_python.SQL_WCHAR, encoding='utf-16le', ctype=mssql_python.SQL_WCHAR) + """ + if self._closed: + raise InterfaceError( + driver_error="Connection is closed", + ddbc_error="Connection is closed", + ) + + # Validate sqltype + valid_sqltypes = [ + ConstantsDDBC.SQL_CHAR.value, + ConstantsDDBC.SQL_WCHAR.value, + SQL_WMETADATA + ] + if sqltype not in valid_sqltypes: + log('warning', "Invalid sqltype attempted: %s", sanitize_user_input(str(sqltype))) + raise ProgrammingError( + driver_error=f"Invalid sqltype: {sqltype}", + ddbc_error=f"sqltype must be SQL_CHAR ({ConstantsDDBC.SQL_CHAR.value}), SQL_WCHAR ({ConstantsDDBC.SQL_WCHAR.value}), or SQL_WMETADATA ({SQL_WMETADATA})", + ) + + # Set default encoding based on sqltype if not provided + if encoding is None: + if sqltype == ConstantsDDBC.SQL_CHAR.value: + encoding = 'utf-8' # Default for SQL_CHAR in Python 3 + else: # SQL_WCHAR or SQL_WMETADATA + encoding = 'utf-16le' # Default for SQL_WCHAR in Python 3 + + # Validate encoding using cached validation for better performance + if not _validate_encoding(encoding): + log('warning', "Invalid encoding attempted: %s", sanitize_user_input(str(encoding))) + raise ProgrammingError( + driver_error=f"Unsupported encoding: {encoding}", + ddbc_error=f"The encoding '{encoding}' is not supported by Python", + ) + + # Normalize encoding to lowercase for consistency + encoding = encoding.lower() + + # Set default ctype based on encoding if not provided + if ctype is None: + if encoding in UTF16_ENCODINGS: + ctype = ConstantsDDBC.SQL_WCHAR.value + else: + ctype = ConstantsDDBC.SQL_CHAR.value + + # Validate ctype + valid_ctypes = [ConstantsDDBC.SQL_CHAR.value, ConstantsDDBC.SQL_WCHAR.value] + if ctype not in valid_ctypes: + log('warning', "Invalid ctype attempted: %s", sanitize_user_input(str(ctype))) + raise ProgrammingError( + driver_error=f"Invalid ctype: {ctype}", + ddbc_error=f"ctype must be SQL_CHAR ({ConstantsDDBC.SQL_CHAR.value}) or SQL_WCHAR ({ConstantsDDBC.SQL_WCHAR.value})", + ) + + # Store the decoding settings for the specified sqltype + self._decoding_settings[sqltype] = { + 'encoding': encoding, + 'ctype': ctype + } + + # Log with sanitized values for security + sqltype_name = { + ConstantsDDBC.SQL_CHAR.value: "SQL_CHAR", + ConstantsDDBC.SQL_WCHAR.value: "SQL_WCHAR", + SQL_WMETADATA: "SQL_WMETADATA" + }.get(sqltype, str(sqltype)) + + log('info', "Text decoding set for %s to %s with ctype %s", + sqltype_name, sanitize_user_input(encoding), sanitize_user_input(str(ctype))) + + def getdecoding(self, sqltype): + """ + Gets the current text decoding settings for the specified SQL type. + + Args: + sqltype (int): The SQL type to get settings for: SQL_CHAR, SQL_WCHAR, or SQL_WMETADATA. + + Returns: + dict: A dictionary containing 'encoding' and 'ctype' keys for the specified sqltype. + + Raises: + ProgrammingError: If the sqltype is invalid. + InterfaceError: If the connection is closed. + + Example: + settings = cnxn.getdecoding(mssql_python.SQL_CHAR) + print(f"SQL_CHAR encoding: {settings['encoding']}") + print(f"SQL_CHAR ctype: {settings['ctype']}") + """ + if self._closed: + raise InterfaceError( + driver_error="Connection is closed", + ddbc_error="Connection is closed", + ) + + # Validate sqltype + valid_sqltypes = [ + ConstantsDDBC.SQL_CHAR.value, + ConstantsDDBC.SQL_WCHAR.value, + SQL_WMETADATA + ] + if sqltype not in valid_sqltypes: + raise ProgrammingError( + driver_error=f"Invalid sqltype: {sqltype}", + ddbc_error=f"sqltype must be SQL_CHAR ({ConstantsDDBC.SQL_CHAR.value}), SQL_WCHAR ({ConstantsDDBC.SQL_WCHAR.value}), or SQL_WMETADATA ({SQL_WMETADATA})", + ) + + return self._decoding_settings[sqltype].copy() + def cursor(self) -> Cursor: """ Return a new Cursor object using the connection. diff --git a/tests/test_003_connection.py b/tests/test_003_connection.py index c9d043d38..74cccdc92 100644 --- a/tests/test_003_connection.py +++ b/tests/test_003_connection.py @@ -18,7 +18,8 @@ - test_rollback_on_close: Test that rollback occurs on connection close if autocommit is False. """ -from mssql_python.exceptions import InterfaceError +from mssql_python.exceptions import InterfaceError, ProgrammingError +import mssql_python import pytest import time from mssql_python import connect, Connection, pooling, SQL_CHAR, SQL_WCHAR @@ -552,17 +553,15 @@ def test_setencoding_none_parameters(db_connection): def test_setencoding_invalid_encoding(db_connection): """Test setencoding with invalid encoding.""" - from mssql_python.exceptions import ProgrammingError with pytest.raises(ProgrammingError) as exc_info: db_connection.setencoding(encoding='invalid-encoding-name') - assert "Unknown encoding" in str(exc_info.value), "Should raise ProgrammingError for invalid encoding" + assert "Unsupported encoding" in str(exc_info.value), "Should raise ProgrammingError for invalid encoding" assert "invalid-encoding-name" in str(exc_info.value), "Error message should include the invalid encoding name" def test_setencoding_invalid_ctype(db_connection): """Test setencoding with invalid ctype.""" - from mssql_python.exceptions import ProgrammingError with pytest.raises(ProgrammingError) as exc_info: db_connection.setencoding(encoding='utf-8', ctype=999) @@ -572,7 +571,6 @@ def test_setencoding_invalid_ctype(db_connection): def test_setencoding_closed_connection(conn_str): """Test setencoding on closed connection.""" - from mssql_python.exceptions import InterfaceError temp_conn = connect(conn_str) temp_conn.close() @@ -580,7 +578,7 @@ def test_setencoding_closed_connection(conn_str): with pytest.raises(InterfaceError) as exc_info: temp_conn.setencoding(encoding='utf-8') - assert "closed connection" in str(exc_info.value).lower(), "Should raise InterfaceError for closed connection" + assert "Connection is closed" in str(exc_info.value), "Should raise InterfaceError for closed connection" def test_setencoding_constants_access(): """Test that SQL_CHAR and SQL_WCHAR constants are accessible.""" @@ -829,20 +827,8 @@ def test_setencoding_with_explicit_ctype_sql_wchar(conn_str): finally: conn.close() -def test_setencoding_invalid_encoding(conn_str): - """Test setencoding with invalid encoding raises ProgrammingError""" - from mssql_python.exceptions import ProgrammingError - - conn = connect(conn_str) - try: - with pytest.raises(ProgrammingError, match="Unsupported encoding"): - conn.setencoding('invalid-encoding-name') - finally: - conn.close() - -def test_setencoding_invalid_ctype(conn_str): +def test_setencoding_invalid_ctype_error(conn_str): """Test setencoding with invalid ctype raises ProgrammingError""" - from mssql_python.exceptions import ProgrammingError conn = connect(conn_str) try: @@ -851,14 +837,6 @@ def test_setencoding_invalid_ctype(conn_str): finally: conn.close() -def test_setencoding_closed_connection(conn_str): - """Test setencoding on closed connection raises InterfaceError""" - conn = connect(conn_str) - conn.close() - - with pytest.raises(InterfaceError, match="Connection is closed"): - conn.setencoding('utf-8') - def test_setencoding_case_insensitive_encoding(conn_str): """Test setencoding with case variations""" conn = connect(conn_str) @@ -923,4 +901,431 @@ def test_setencoding_cp1252(conn_str): assert encoding_info['encoding'] == 'cp1252' assert encoding_info['ctype'] == SQL_CHAR finally: - conn.close() \ No newline at end of file + conn.close() + +def test_setdecoding_default_settings(db_connection): + """Test that default decoding settings are correct for all SQL types.""" + + # Check SQL_CHAR defaults + sql_char_settings = db_connection.getdecoding(mssql_python.SQL_CHAR) + assert sql_char_settings['encoding'] == 'utf-8', "Default SQL_CHAR encoding should be utf-8" + assert sql_char_settings['ctype'] == mssql_python.SQL_CHAR, "Default SQL_CHAR ctype should be SQL_CHAR" + + # Check SQL_WCHAR defaults + sql_wchar_settings = db_connection.getdecoding(mssql_python.SQL_WCHAR) + assert sql_wchar_settings['encoding'] == 'utf-16le', "Default SQL_WCHAR encoding should be utf-16le" + assert sql_wchar_settings['ctype'] == mssql_python.SQL_WCHAR, "Default SQL_WCHAR ctype should be SQL_WCHAR" + + # Check SQL_WMETADATA defaults + sql_wmetadata_settings = db_connection.getdecoding(mssql_python.SQL_WMETADATA) + assert sql_wmetadata_settings['encoding'] == 'utf-16le', "Default SQL_WMETADATA encoding should be utf-16le" + assert sql_wmetadata_settings['ctype'] == mssql_python.SQL_WCHAR, "Default SQL_WMETADATA ctype should be SQL_WCHAR" + +def test_setdecoding_basic_functionality(db_connection): + """Test basic setdecoding functionality for different SQL types.""" + + # Test setting SQL_CHAR decoding + db_connection.setdecoding(mssql_python.SQL_CHAR, encoding='latin-1') + settings = db_connection.getdecoding(mssql_python.SQL_CHAR) + assert settings['encoding'] == 'latin-1', "SQL_CHAR encoding should be set to latin-1" + assert settings['ctype'] == mssql_python.SQL_CHAR, "SQL_CHAR ctype should default to SQL_CHAR for latin-1" + + # Test setting SQL_WCHAR decoding + db_connection.setdecoding(mssql_python.SQL_WCHAR, encoding='utf-16be') + settings = db_connection.getdecoding(mssql_python.SQL_WCHAR) + assert settings['encoding'] == 'utf-16be', "SQL_WCHAR encoding should be set to utf-16be" + assert settings['ctype'] == mssql_python.SQL_WCHAR, "SQL_WCHAR ctype should default to SQL_WCHAR for utf-16be" + + # Test setting SQL_WMETADATA decoding + db_connection.setdecoding(mssql_python.SQL_WMETADATA, encoding='utf-16le') + settings = db_connection.getdecoding(mssql_python.SQL_WMETADATA) + assert settings['encoding'] == 'utf-16le', "SQL_WMETADATA encoding should be set to utf-16le" + assert settings['ctype'] == mssql_python.SQL_WCHAR, "SQL_WMETADATA ctype should default to SQL_WCHAR" + +def test_setdecoding_automatic_ctype_detection(db_connection): + """Test automatic ctype detection based on encoding for different SQL types.""" + + # UTF-16 variants should default to SQL_WCHAR + utf16_encodings = ['utf-16', 'utf-16le', 'utf-16be'] + for encoding in utf16_encodings: + db_connection.setdecoding(mssql_python.SQL_CHAR, encoding=encoding) + settings = db_connection.getdecoding(mssql_python.SQL_CHAR) + assert settings['ctype'] == mssql_python.SQL_WCHAR, f"SQL_CHAR with {encoding} should auto-detect SQL_WCHAR ctype" + + # Other encodings should default to SQL_CHAR + other_encodings = ['utf-8', 'latin-1', 'ascii', 'cp1252'] + for encoding in other_encodings: + db_connection.setdecoding(mssql_python.SQL_WCHAR, encoding=encoding) + settings = db_connection.getdecoding(mssql_python.SQL_WCHAR) + assert settings['ctype'] == mssql_python.SQL_CHAR, f"SQL_WCHAR with {encoding} should auto-detect SQL_CHAR ctype" + +def test_setdecoding_explicit_ctype_override(db_connection): + """Test that explicit ctype parameter overrides automatic detection.""" + + # Set SQL_CHAR with UTF-8 encoding but explicit SQL_WCHAR ctype + db_connection.setdecoding(mssql_python.SQL_CHAR, encoding='utf-8', ctype=mssql_python.SQL_WCHAR) + settings = db_connection.getdecoding(mssql_python.SQL_CHAR) + assert settings['encoding'] == 'utf-8', "Encoding should be utf-8" + assert settings['ctype'] == mssql_python.SQL_WCHAR, "ctype should be SQL_WCHAR when explicitly set" + + # Set SQL_WCHAR with UTF-16LE encoding but explicit SQL_CHAR ctype + db_connection.setdecoding(mssql_python.SQL_WCHAR, encoding='utf-16le', ctype=mssql_python.SQL_CHAR) + settings = db_connection.getdecoding(mssql_python.SQL_WCHAR) + assert settings['encoding'] == 'utf-16le', "Encoding should be utf-16le" + assert settings['ctype'] == mssql_python.SQL_CHAR, "ctype should be SQL_CHAR when explicitly set" + +def test_setdecoding_none_parameters(db_connection): + """Test setdecoding with None parameters uses appropriate defaults.""" + + # Test SQL_CHAR with encoding=None (should use utf-8 default) + db_connection.setdecoding(mssql_python.SQL_CHAR, encoding=None) + settings = db_connection.getdecoding(mssql_python.SQL_CHAR) + assert settings['encoding'] == 'utf-8', "SQL_CHAR with encoding=None should use utf-8 default" + assert settings['ctype'] == mssql_python.SQL_CHAR, "ctype should be SQL_CHAR for utf-8" + + # Test SQL_WCHAR with encoding=None (should use utf-16le default) + db_connection.setdecoding(mssql_python.SQL_WCHAR, encoding=None) + settings = db_connection.getdecoding(mssql_python.SQL_WCHAR) + assert settings['encoding'] == 'utf-16le', "SQL_WCHAR with encoding=None should use utf-16le default" + assert settings['ctype'] == mssql_python.SQL_WCHAR, "ctype should be SQL_WCHAR for utf-16le" + + # Test with both parameters None + db_connection.setdecoding(mssql_python.SQL_CHAR, encoding=None, ctype=None) + settings = db_connection.getdecoding(mssql_python.SQL_CHAR) + assert settings['encoding'] == 'utf-8', "SQL_CHAR with both None should use utf-8 default" + assert settings['ctype'] == mssql_python.SQL_CHAR, "ctype should default to SQL_CHAR" + +def test_setdecoding_invalid_sqltype(db_connection): + """Test setdecoding with invalid sqltype raises ProgrammingError.""" + + with pytest.raises(ProgrammingError) as exc_info: + db_connection.setdecoding(999, encoding='utf-8') + + assert "Invalid sqltype" in str(exc_info.value), "Should raise ProgrammingError for invalid sqltype" + assert "999" in str(exc_info.value), "Error message should include the invalid sqltype value" + +def test_setdecoding_invalid_encoding(db_connection): + """Test setdecoding with invalid encoding raises ProgrammingError.""" + + with pytest.raises(ProgrammingError) as exc_info: + db_connection.setdecoding(mssql_python.SQL_CHAR, encoding='invalid-encoding-name') + + assert "Unsupported encoding" in str(exc_info.value), "Should raise ProgrammingError for invalid encoding" + assert "invalid-encoding-name" in str(exc_info.value), "Error message should include the invalid encoding name" + +def test_setdecoding_invalid_ctype(db_connection): + """Test setdecoding with invalid ctype raises ProgrammingError.""" + + with pytest.raises(ProgrammingError) as exc_info: + db_connection.setdecoding(mssql_python.SQL_CHAR, encoding='utf-8', ctype=999) + + assert "Invalid ctype" in str(exc_info.value), "Should raise ProgrammingError for invalid ctype" + assert "999" in str(exc_info.value), "Error message should include the invalid ctype value" + +def test_setdecoding_closed_connection(conn_str): + """Test setdecoding on closed connection raises InterfaceError.""" + + temp_conn = connect(conn_str) + temp_conn.close() + + with pytest.raises(InterfaceError) as exc_info: + temp_conn.setdecoding(mssql_python.SQL_CHAR, encoding='utf-8') + + assert "Connection is closed" in str(exc_info.value), "Should raise InterfaceError for closed connection" + +def test_setdecoding_constants_access(): + """Test that SQL constants are accessible.""" + + # Test constants exist and have correct values + assert hasattr(mssql_python, 'SQL_CHAR'), "SQL_CHAR constant should be available" + assert hasattr(mssql_python, 'SQL_WCHAR'), "SQL_WCHAR constant should be available" + assert hasattr(mssql_python, 'SQL_WMETADATA'), "SQL_WMETADATA constant should be available" + + assert mssql_python.SQL_CHAR == 1, "SQL_CHAR should have value 1" + assert mssql_python.SQL_WCHAR == -8, "SQL_WCHAR should have value -8" + assert mssql_python.SQL_WMETADATA == -99, "SQL_WMETADATA should have value -99" + +def test_setdecoding_with_constants(db_connection): + """Test setdecoding using module constants.""" + + # Test with SQL_CHAR constant + db_connection.setdecoding(mssql_python.SQL_CHAR, encoding='utf-8', ctype=mssql_python.SQL_CHAR) + settings = db_connection.getdecoding(mssql_python.SQL_CHAR) + assert settings['ctype'] == mssql_python.SQL_CHAR, "Should accept SQL_CHAR constant" + + # Test with SQL_WCHAR constant + db_connection.setdecoding(mssql_python.SQL_WCHAR, encoding='utf-16le', ctype=mssql_python.SQL_WCHAR) + settings = db_connection.getdecoding(mssql_python.SQL_WCHAR) + assert settings['ctype'] == mssql_python.SQL_WCHAR, "Should accept SQL_WCHAR constant" + + # Test with SQL_WMETADATA constant + db_connection.setdecoding(mssql_python.SQL_WMETADATA, encoding='utf-16be') + settings = db_connection.getdecoding(mssql_python.SQL_WMETADATA) + assert settings['encoding'] == 'utf-16be', "Should accept SQL_WMETADATA constant" + +def test_setdecoding_common_encodings(db_connection): + """Test setdecoding with various common encodings.""" + + common_encodings = [ + 'utf-8', + 'utf-16le', + 'utf-16be', + 'utf-16', + 'latin-1', + 'ascii', + 'cp1252' + ] + + for encoding in common_encodings: + try: + db_connection.setdecoding(mssql_python.SQL_CHAR, encoding=encoding) + settings = db_connection.getdecoding(mssql_python.SQL_CHAR) + assert settings['encoding'] == encoding, f"Failed to set SQL_CHAR decoding to {encoding}" + + db_connection.setdecoding(mssql_python.SQL_WCHAR, encoding=encoding) + settings = db_connection.getdecoding(mssql_python.SQL_WCHAR) + assert settings['encoding'] == encoding, f"Failed to set SQL_WCHAR decoding to {encoding}" + except Exception as e: + pytest.fail(f"Failed to set valid encoding {encoding}: {e}") + +def test_setdecoding_case_insensitive_encoding(db_connection): + """Test setdecoding with case variations normalizes encoding.""" + + # Test various case formats + db_connection.setdecoding(mssql_python.SQL_CHAR, encoding='UTF-8') + settings = db_connection.getdecoding(mssql_python.SQL_CHAR) + assert settings['encoding'] == 'utf-8', "Encoding should be normalized to lowercase" + + db_connection.setdecoding(mssql_python.SQL_WCHAR, encoding='Utf-16LE') + settings = db_connection.getdecoding(mssql_python.SQL_WCHAR) + assert settings['encoding'] == 'utf-16le', "Encoding should be normalized to lowercase" + +def test_setdecoding_independent_sql_types(db_connection): + """Test that decoding settings for different SQL types are independent.""" + + # Set different encodings for each SQL type + db_connection.setdecoding(mssql_python.SQL_CHAR, encoding='utf-8') + db_connection.setdecoding(mssql_python.SQL_WCHAR, encoding='utf-16le') + db_connection.setdecoding(mssql_python.SQL_WMETADATA, encoding='utf-16be') + + # Verify each maintains its own settings + sql_char_settings = db_connection.getdecoding(mssql_python.SQL_CHAR) + sql_wchar_settings = db_connection.getdecoding(mssql_python.SQL_WCHAR) + sql_wmetadata_settings = db_connection.getdecoding(mssql_python.SQL_WMETADATA) + + assert sql_char_settings['encoding'] == 'utf-8', "SQL_CHAR should maintain utf-8" + assert sql_wchar_settings['encoding'] == 'utf-16le', "SQL_WCHAR should maintain utf-16le" + assert sql_wmetadata_settings['encoding'] == 'utf-16be', "SQL_WMETADATA should maintain utf-16be" + +def test_setdecoding_override_previous(db_connection): + """Test setdecoding overrides previous settings for the same SQL type.""" + + # Set initial decoding + db_connection.setdecoding(mssql_python.SQL_CHAR, encoding='utf-8') + settings = db_connection.getdecoding(mssql_python.SQL_CHAR) + assert settings['encoding'] == 'utf-8', "Initial encoding should be utf-8" + assert settings['ctype'] == mssql_python.SQL_CHAR, "Initial ctype should be SQL_CHAR" + + # Override with different settings + db_connection.setdecoding(mssql_python.SQL_CHAR, encoding='latin-1', ctype=mssql_python.SQL_WCHAR) + settings = db_connection.getdecoding(mssql_python.SQL_CHAR) + assert settings['encoding'] == 'latin-1', "Encoding should be overridden to latin-1" + assert settings['ctype'] == mssql_python.SQL_WCHAR, "ctype should be overridden to SQL_WCHAR" + +def test_getdecoding_invalid_sqltype(db_connection): + """Test getdecoding with invalid sqltype raises ProgrammingError.""" + + with pytest.raises(ProgrammingError) as exc_info: + db_connection.getdecoding(999) + + assert "Invalid sqltype" in str(exc_info.value), "Should raise ProgrammingError for invalid sqltype" + assert "999" in str(exc_info.value), "Error message should include the invalid sqltype value" + +def test_getdecoding_closed_connection(conn_str): + """Test getdecoding on closed connection raises InterfaceError.""" + + temp_conn = connect(conn_str) + temp_conn.close() + + with pytest.raises(InterfaceError) as exc_info: + temp_conn.getdecoding(mssql_python.SQL_CHAR) + + assert "Connection is closed" in str(exc_info.value), "Should raise InterfaceError for closed connection" + +def test_getdecoding_returns_copy(db_connection): + """Test getdecoding returns a copy (not reference).""" + + # Set custom decoding + db_connection.setdecoding(mssql_python.SQL_CHAR, encoding='utf-8') + + # Get settings twice + settings1 = db_connection.getdecoding(mssql_python.SQL_CHAR) + settings2 = db_connection.getdecoding(mssql_python.SQL_CHAR) + + # Should be equal but not the same object + assert settings1 == settings2, "Settings should be equal" + assert settings1 is not settings2, "Settings should be different objects" + + # Modifying one shouldn't affect the other + settings1['encoding'] = 'modified' + assert settings2['encoding'] != 'modified', "Modification should not affect other copy" + +def test_setdecoding_getdecoding_consistency(db_connection): + """Test that setdecoding and getdecoding work consistently together.""" + + test_cases = [ + (mssql_python.SQL_CHAR, 'utf-8', mssql_python.SQL_CHAR), + (mssql_python.SQL_CHAR, 'utf-16le', mssql_python.SQL_WCHAR), + (mssql_python.SQL_WCHAR, 'latin-1', mssql_python.SQL_CHAR), + (mssql_python.SQL_WCHAR, 'utf-16be', mssql_python.SQL_WCHAR), + (mssql_python.SQL_WMETADATA, 'utf-16le', mssql_python.SQL_WCHAR), + ] + + for sqltype, encoding, expected_ctype in test_cases: + db_connection.setdecoding(sqltype, encoding=encoding) + settings = db_connection.getdecoding(sqltype) + assert settings['encoding'] == encoding.lower(), f"Encoding should be {encoding.lower()}" + assert settings['ctype'] == expected_ctype, f"ctype should be {expected_ctype}" + +def test_setdecoding_persistence_across_cursors(db_connection): + """Test that decoding settings persist across cursor operations.""" + + # Set custom decoding settings + db_connection.setdecoding(mssql_python.SQL_CHAR, encoding='latin-1', ctype=mssql_python.SQL_CHAR) + db_connection.setdecoding(mssql_python.SQL_WCHAR, encoding='utf-16be', ctype=mssql_python.SQL_WCHAR) + + # Create cursors and verify settings persist + cursor1 = db_connection.cursor() + char_settings1 = db_connection.getdecoding(mssql_python.SQL_CHAR) + wchar_settings1 = db_connection.getdecoding(mssql_python.SQL_WCHAR) + + cursor2 = db_connection.cursor() + char_settings2 = db_connection.getdecoding(mssql_python.SQL_CHAR) + wchar_settings2 = db_connection.getdecoding(mssql_python.SQL_WCHAR) + + # Settings should persist across cursor creation + assert char_settings1 == char_settings2, "SQL_CHAR settings should persist across cursors" + assert wchar_settings1 == wchar_settings2, "SQL_WCHAR settings should persist across cursors" + + assert char_settings1['encoding'] == 'latin-1', "SQL_CHAR encoding should remain latin-1" + assert wchar_settings1['encoding'] == 'utf-16be', "SQL_WCHAR encoding should remain utf-16be" + + cursor1.close() + cursor2.close() + +def test_setdecoding_before_and_after_operations(db_connection): + """Test that setdecoding works both before and after database operations.""" + cursor = db_connection.cursor() + + try: + # Initial decoding setting + db_connection.setdecoding(mssql_python.SQL_CHAR, encoding='utf-8') + + # Perform database operation + cursor.execute("SELECT 'Initial test' as message") + result1 = cursor.fetchone() + assert result1[0] == 'Initial test', "Initial operation failed" + + # Change decoding after operation + db_connection.setdecoding(mssql_python.SQL_CHAR, encoding='latin-1') + settings = db_connection.getdecoding(mssql_python.SQL_CHAR) + assert settings['encoding'] == 'latin-1', "Failed to change decoding after operation" + + # Perform another operation with new decoding + cursor.execute("SELECT 'Changed decoding test' as message") + result2 = cursor.fetchone() + assert result2[0] == 'Changed decoding test', "Operation after decoding change failed" + + except Exception as e: + pytest.fail(f"Decoding change test failed: {e}") + finally: + cursor.close() + +def test_setdecoding_all_sql_types_independently(conn_str): + """Test setdecoding with all SQL types on a fresh connection.""" + + conn = connect(conn_str) + try: + # Test each SQL type with different configurations + test_configs = [ + (mssql_python.SQL_CHAR, 'ascii', mssql_python.SQL_CHAR), + (mssql_python.SQL_WCHAR, 'utf-16le', mssql_python.SQL_WCHAR), + (mssql_python.SQL_WMETADATA, 'utf-16be', mssql_python.SQL_WCHAR), + ] + + for sqltype, encoding, ctype in test_configs: + conn.setdecoding(sqltype, encoding=encoding, ctype=ctype) + settings = conn.getdecoding(sqltype) + assert settings['encoding'] == encoding, f"Failed to set encoding for sqltype {sqltype}" + assert settings['ctype'] == ctype, f"Failed to set ctype for sqltype {sqltype}" + + finally: + conn.close() + +def test_setdecoding_security_logging(db_connection): + """Test that setdecoding logs invalid attempts safely.""" + + # These should raise exceptions but not crash due to logging + test_cases = [ + (999, 'utf-8', None), # Invalid sqltype + (mssql_python.SQL_CHAR, 'invalid-encoding', None), # Invalid encoding + (mssql_python.SQL_CHAR, 'utf-8', 999), # Invalid ctype + ] + + for sqltype, encoding, ctype in test_cases: + with pytest.raises(ProgrammingError): + db_connection.setdecoding(sqltype, encoding=encoding, ctype=ctype) + +@pytest.mark.skip("Skipping Unicode data tests till we have support for Unicode") +def test_setdecoding_with_unicode_data(db_connection): + """Test setdecoding with actual Unicode data operations.""" + + # Test different decoding configurations with Unicode data + db_connection.setdecoding(mssql_python.SQL_CHAR, encoding='utf-8') + db_connection.setdecoding(mssql_python.SQL_WCHAR, encoding='utf-16le') + + cursor = db_connection.cursor() + + try: + # Create test table with both CHAR and NCHAR columns + cursor.execute(""" + CREATE TABLE #test_decoding_unicode ( + char_col VARCHAR(100), + nchar_col NVARCHAR(100) + ) + """) + + # Test various Unicode strings + test_strings = [ + "Hello, World!", + "Hello, 世界!", # Chinese + "Привет, мир!", # Russian + "مرحبا بالعالم", # Arabic + ] + + for test_string in test_strings: + # Insert data + cursor.execute( + "INSERT INTO #test_decoding_unicode (char_col, nchar_col) VALUES (?, ?)", + test_string, test_string + ) + + # Retrieve and verify + cursor.execute("SELECT char_col, nchar_col FROM #test_decoding_unicode WHERE char_col = ?", test_string) + result = cursor.fetchone() + + assert result is not None, f"Failed to retrieve Unicode string: {test_string}" + assert result[0] == test_string, f"CHAR column mismatch: expected {test_string}, got {result[0]}" + assert result[1] == test_string, f"NCHAR column mismatch: expected {test_string}, got {result[1]}" + + # Clear for next test + cursor.execute("DELETE FROM #test_decoding_unicode") + + except Exception as e: + pytest.fail(f"Unicode data test failed with custom decoding: {e}") + finally: + try: + cursor.execute("DROP TABLE #test_decoding_unicode") + except: + pass + cursor.close() \ No newline at end of file