From 7f63f08c497f32ab4672d040e9ae297867b1bb56 Mon Sep 17 00:00:00 2001 From: Peng Ren Date: Mon, 2 Feb 2026 20:56:08 -0500 Subject: [PATCH 1/2] Bump version --- pymongosql/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymongosql/__init__.py b/pymongosql/__init__.py index d377b5a..34b8182 100644 --- a/pymongosql/__init__.py +++ b/pymongosql/__init__.py @@ -6,7 +6,7 @@ if TYPE_CHECKING: from .connection import Connection -__version__: str = "0.4.0" +__version__: str = "0.4.1" # Globals https://www.python.org/dev/peps/pep-0249/#globals apilevel: str = "2.0" From ed98d6c916b06ed94fd0ef59e58e1d54eeacd315 Mon Sep 17 00:00:00 2001 From: Peng Ren Date: Fri, 6 Feb 2026 17:34:29 -0500 Subject: [PATCH 2/2] Support value functions --- README.md | 40 +++- pymongosql/sql/handler.py | 88 ++++++- pymongosql/sql/value_function_registry.py | 271 +++++++++++++++++++++ tests/test_cursor.py | 57 +++++ tests/test_value_function_registry.py | 274 ++++++++++++++++++++++ 5 files changed, 720 insertions(+), 10 deletions(-) create mode 100644 pymongosql/sql/value_function_registry.py create mode 100644 tests/test_value_function_registry.py diff --git a/README.md b/README.md index 7483d44..1efb4ea 100644 --- a/README.md +++ b/README.md @@ -223,6 +223,37 @@ Parameters are substituted into the MongoDB filter during execution, providing p - **Logical operators**: `WHERE age > 18 AND status = 'active'`, `WHERE age < 30 OR role = 'admin'` - **Nested field filtering**: `WHERE profile.status = 'active'` - **Array filtering**: `WHERE items[0].price > 100` +- **Value Functions**: Apply transformations to values in WHERE clauses for filtering + +#### Value Functions + +PyMongoSQL supports value functions to transform and filter values in WHERE clauses. Built-in value functions include: + +**str_to_datetime()** - Convert ISO 8601 or custom formatted strings to Python datetime objects + +```python +# ISO 8601 format +cursor.execute("SELECT * FROM events WHERE created_at >= str_to_datetime('2024-01-15T10:30:00Z')") + +# Custom format +cursor.execute("SELECT * FROM events WHERE created_at < str_to_datetime('03/15/2024', '%m/%d/%Y')") +``` + +**str_to_timestamp()** - Convert ISO 8601 or custom formatted strings to BSON Timestamp objects + +```python +# ISO 8601 format +cursor.execute("SELECT * FROM logs WHERE timestamp > str_to_timestamp('2024-01-15T00:00:00Z')") + +# Custom format +cursor.execute("SELECT * FROM logs WHERE timestamp < str_to_timestamp('01/15/2024', '%m/%d/%Y')") +``` + +Both functions: +- Support ISO 8601 strings with 'Z' timezone indicator +- Support custom format strings using Python strftime directives +- Return values with UTC timezone +- Can be combined with standard SQL operators (>, <, >=, <=, =, !=) ### Nested Field Support - **Single-level**: `profile.name`, `settings.theme` @@ -505,15 +536,6 @@ PyMongoSQL can be used as a database driver in Apache Superset for querying and This allows seamless integration between MongoDB data and Superset's BI capabilities without requiring data migration to traditional SQL databases. -## Limitations & Roadmap - -**Note**: PyMongoSQL currently supports DQL (Data Query Language) and DML (Data Manipulation Language) operations. The following SQL features are **not yet supported** but are planned for future releases: - -- **Advanced DML Operations** - - `REPLACE`, `MERGE`, `UPSERT` - -These features are on our development roadmap and contributions are welcome! - ## Contributing Contributions are welcome! Please feel free to submit a Pull Request. For major changes, please open an issue first to discuss what you would like to change. diff --git a/pymongosql/sql/handler.py b/pymongosql/sql/handler.py index 794c0ca..19461fe 100644 --- a/pymongosql/sql/handler.py +++ b/pymongosql/sql/handler.py @@ -401,13 +401,99 @@ def _extract_value(self, ctx: Any) -> Any: if operator: parts = self._split_by_operator(text, operator) if len(parts) >= 2: - return self._parse_value(parts[1].strip("()")) + value_text = parts[1].strip() + # Check if value is a function call + return self._extract_value_or_function(value_text) return None except Exception as e: _logger.debug(f"Failed to extract value: {e}") return None + def _extract_value_or_function(self, value_text: str) -> Any: + """ + Extract value, which could be a literal or a value function call. + + Detects and executes value functions like str_to_datetime(...), str_to_timestamp(...). + + Args: + value_text: Text representing the value, possibly a function call + + Returns: + Processed value (result of function execution if function, otherwise parsed literal) + """ + value_text = value_text.strip() + + # Check if this looks like a function call: func_name(...) + if "(" in value_text and value_text.endswith(")"): + # Extract function name and arguments + paren_pos = value_text.find("(") + func_name = value_text[:paren_pos].strip() + + # Check if it's a valid identifier (function name) + if func_name.isidentifier(): + try: + from .value_function_registry import get_default_registry + + registry = get_default_registry() + if registry.has_function(func_name): + # It's a registered value function - execute it + args_text = value_text[paren_pos + 1 : -1] + args = self._parse_function_arguments(args_text) + result = registry.execute(func_name, args) + _logger.debug(f"Executed value function: {func_name}({args}) -> {result}") + return result + except Exception as e: + _logger.warning(f"Failed to execute value function '{func_name}': {e}") + # Fall through to treat as regular value + + # Not a function call or function execution failed - treat as regular value + return self._parse_value(value_text) + + def _parse_function_arguments(self, args_text: str) -> list: + """ + Parse function arguments from comma-separated string. + + Handles string literals with quotes and nested structures. + + Args: + args_text: Text of function arguments (e.g., "'2024-01-01', '%m/%d/%Y'") + + Returns: + List of parsed argument values + """ + if not args_text.strip(): + return [] + + args = [] + current_arg = "" + in_quotes = False + quote_char = None + + for char in args_text: + if char in ('"', "'") and not in_quotes: + in_quotes = True + quote_char = char + current_arg += char + elif char == quote_char and in_quotes: + in_quotes = False + quote_char = None + current_arg += char + elif char == "," and not in_quotes: + # End of argument + arg = current_arg.strip() + if arg: + args.append(self._parse_value(arg)) + current_arg = "" + else: + current_arg += char + + # Don't forget the last argument + if current_arg.strip(): + args.append(self._parse_value(current_arg.strip())) + + return args + def _extract_in_values(self, text: str) -> List[Any]: """Extract values from IN clause""" # Handle both 'IN(' and 'IN (' patterns diff --git a/pymongosql/sql/value_function_registry.py b/pymongosql/sql/value_function_registry.py new file mode 100644 index 0000000..615456b --- /dev/null +++ b/pymongosql/sql/value_function_registry.py @@ -0,0 +1,271 @@ +# -*- coding: utf-8 -*- +import logging +from datetime import datetime, timezone +from typing import Any, Callable, Dict, List, Optional + +from bson.timestamp import Timestamp + +_logger = logging.getLogger(__name__) + + +class ValueFunctionExecutionError(Exception): + """Raised when a value function execution fails""" + + pass + + +class ValueFunctionRegistry: + """Registry for managing custom value transformation functions""" + + def __init__(self): + """Initialize the registry with built-in functions""" + self._functions: Dict[str, Callable] = {} + self._register_builtin_functions() + + def _register_builtin_functions(self) -> None: + """Register built-in value transformation functions""" + self.register("str_to_datetime", self.str_to_datetime) + self.register("str_to_timestamp", self.str_to_timestamp) + + def register(self, func_name: str, func: Callable) -> None: + """ + Register a custom value function. + + Args: + func_name: Name of the function (case-insensitive) + func: Callable that takes arguments and returns transformed value + + Raises: + ValueError: If func_name is already registered or invalid + """ + if not isinstance(func_name, str) or not func_name.strip(): + raise ValueError("Function name must be a non-empty string") + + if not callable(func): + raise ValueError(f"Function {func_name} must be callable") + + func_name_lower = func_name.lower() + if func_name_lower in self._functions: + _logger.warning(f"Overwriting existing function: {func_name}") + + self._functions[func_name_lower] = func + _logger.debug(f"Registered value function: {func_name}") + + def unregister(self, func_name: str) -> None: + """ + Unregister a value function. + + Args: + func_name: Name of the function to unregister + """ + func_name_lower = func_name.lower() + if func_name_lower in self._functions: + del self._functions[func_name_lower] + _logger.debug(f"Unregistered value function: {func_name}") + + def execute(self, func_name: str, args: List[Any]) -> Any: + """ + Execute a registered value function. + + Args: + func_name: Name of the function to execute + args: List of arguments to pass to the function + + Returns: + The result of the function execution + + Raises: + ValueFunctionExecutionError: If function not found or execution fails + """ + func_name_lower = func_name.lower() + + if func_name_lower not in self._functions: + raise ValueFunctionExecutionError( + f"Value function '{func_name}' not found. " f"Available functions: {list(self._functions.keys())}" + ) + + try: + func = self._functions[func_name_lower] + result = func(*args) + _logger.debug(f"Executed value function: {func_name}({args}) -> {result}") + return result + except TypeError as e: + raise ValueFunctionExecutionError(f"Invalid arguments for function '{func_name}': {str(e)}") from e + except Exception as e: + raise ValueFunctionExecutionError(f"Error executing function '{func_name}': {str(e)}") from e + + def has_function(self, func_name: str) -> bool: + """Check if a function is registered""" + return func_name.lower() in self._functions + + def list_functions(self) -> List[str]: + """Get list of registered function names""" + return list(self._functions.keys()) + + # ========================================================================= + # Built-in Value Functions + # ========================================================================= + + @staticmethod + def str_to_datetime(*args) -> datetime: + """ + Convert string value to Python datetime object. + + Supports two signatures: + - str_to_datetime(val): Convert ISO 8601 formatted string to datetime + - str_to_datetime(val, format): Convert string using custom format + + Args: + *args: Either (val,) or (val, format) + val: String representation of datetime + format: Python datetime format string (Python strftime directives) + + Returns: + datetime: Python datetime object with UTC timezone + + Raises: + ValueError: If arguments are invalid or parsing fails + + Examples: + str_to_datetime('2024-01-15') # ISO 8601 + str_to_datetime('2024-01-15T10:30:00Z') # ISO 8601 with time + str_to_datetime('01/15/2024', '%m/%d/%Y') # Custom format + """ + if not args: + raise ValueError("str_to_datetime() requires at least 1 argument (val)") + + if len(args) > 2: + raise ValueError(f"str_to_datetime() takes at most 2 arguments ({len(args)} given)") + + val = args[0] + + # Validate input + if not isinstance(val, str): + raise ValueError(f"str_to_datetime() val must be string, got {type(val).__name__}") + + val = val.strip() + if val.endswith("Z"): + val = val[:-1] + "+00:00" + + try: + if len(args) == 1: + # ISO 8601 format - use fromisoformat for standard parsing + dt = datetime.fromisoformat(val) + # Ensure UTC timezone + if dt.tzinfo is None: + dt = dt.replace(tzinfo=timezone.utc) + return dt + else: + # Custom format + format_str = args[1] + if not isinstance(format_str, str): + raise ValueError(f"str_to_datetime() format must be string, got {type(format_str).__name__}") + dt = datetime.strptime(val, format_str.strip()) + # Ensure UTC timezone + if dt.tzinfo is None: + dt = dt.replace(tzinfo=timezone.utc) + return dt + + except ValueError as e: + raise ValueError(f"Failed to parse str_to_datetime from '{val}': {str(e)}") from e + + @staticmethod + def str_to_timestamp(*args) -> Timestamp: + """ + Convert string value to BSON Timestamp object. + + Supports two signatures: + - str_to_timestamp(val): Convert ISO 8601 formatted string to Timestamp + - str_to_timestamp(val, format): Convert string using custom format + + Args: + *args: Either (val,) or (val, format) + val: String representation of datetime + format: Python datetime format string (Python strftime directives) + + Returns: + bson.timestamp.Timestamp: BSON Timestamp object + + Raises: + ValueError: If arguments are invalid or parsing fails + + Notes: + - Timestamp uses Unix epoch seconds and an increment counter + - This function uses seconds from epoch and increments by 1 for the counter + - Used primarily for MongoDB replication operations + + Examples: + str_to_timestamp('2024-01-15') # ISO 8601 + str_to_timestamp('2024-01-15T10:30:00Z') # ISO 8601 with time + str_to_timestamp('01/15/2024', '%m/%d/%Y') # Custom format + """ + if not args: + raise ValueError("str_to_timestamp() requires at least 1 argument (val)") + + if len(args) > 2: + raise ValueError(f"str_to_timestamp() takes at most 2 arguments ({len(args)} given)") + + val = args[0] + + # Validate input + if not isinstance(val, str): + raise ValueError(f"str_to_timestamp() val must be string, got {type(val).__name__}") + + val = val.strip() + if val.endswith("Z"): + val = val[:-1] + "+00:00" + + try: + # First parse to datetime using same logic as datetime function + if len(args) == 1: + # ISO 8601 format + dt = datetime.fromisoformat(val) + if dt.tzinfo is None: + dt = dt.replace(tzinfo=timezone.utc) + else: + # Custom format + format_str = args[1] + if not isinstance(format_str, str): + raise ValueError(f"str_to_timestamp() format must be string, got {type(format_str).__name__}") + dt = datetime.strptime(val, format_str.strip()) + if dt.tzinfo is None: + dt = dt.replace(tzinfo=timezone.utc) + + # Convert datetime to Timestamp + # Timestamp(time, inc) where time is Unix epoch in seconds + time_seconds = int(dt.timestamp()) + # Use increment of 1 for conversion operations + ts = Timestamp(time=time_seconds, inc=1) + return ts + + except ValueError as e: + raise ValueError(f"Failed to parse str_to_timestamp from '{val}': {str(e)}") from e + + +# Global singleton instance +_default_registry: Optional[ValueFunctionRegistry] = None + + +def get_default_registry() -> ValueFunctionRegistry: + """Get or create the default value function registry""" + global _default_registry + if _default_registry is None: + _default_registry = ValueFunctionRegistry() + return _default_registry + + +def execute_value_function(func_name: str, args: List[Any]) -> Any: + """ + Execute a value function using the default registry. + + Args: + func_name: Name of the function + args: Arguments to pass to the function + + Returns: + Result of function execution + + Raises: + ValueFunctionExecutionError: If function execution fails + """ + return get_default_registry().execute(func_name, args) diff --git a/tests/test_cursor.py b/tests/test_cursor.py index 2584f28..0d49a5d 100644 --- a/tests/test_cursor.py +++ b/tests/test_cursor.py @@ -467,3 +467,60 @@ def test_execute_with_reserved_keyword_field_in_where(self, conn): rows = cursor.result_set.fetchall() assert len(rows) == 3 assert len(rows[0]) == 1 + + def test_execute_with_value_function_datetime_filter(self, conn): + """Test filtering with value function str_to_datetime() in WHERE clause""" + sql = """ + SELECT name, created_at FROM users + WHERE created_at >= str_to_datetime('2023-01-01T00:00:00Z') + AND created_at < str_to_datetime('2023/03/01', '%Y/%m/%d') + """ + cursor = conn.cursor() + result = cursor.execute(sql) + + assert result == cursor + assert isinstance(cursor.result_set, ResultSet) + + rows = cursor.result_set.fetchall() + assert len(rows) == 3 + + col_names = [desc[0] for desc in cursor.result_set.description] + assert "created_at" in col_names + created_at_idx = col_names.index("created_at") + + for row in rows: + created_at = row[created_at_idx] + assert isinstance(created_at, datetime) + if created_at.tzinfo is None: + created_at = created_at.replace(tzinfo=timezone.utc) + assert created_at >= datetime(2023, 1, 1, tzinfo=timezone.utc) + assert created_at < datetime(2023, 3, 1, tzinfo=timezone.utc) + + def test_execute_with_value_function_timestamp_filter(self, conn): + """Test filtering with value function str_to_timestamp() in WHERE clause""" + sql = """ + SELECT name, "date" FROM users + WHERE "date" > str_to_timestamp('2025-01-01T00:00:00Z') + AND "date" < str_to_timestamp('2026/01/01', '%Y/%m/%d') + """ + cursor = conn.cursor() + result = cursor.execute(sql) + + assert result == cursor + assert isinstance(cursor.result_set, ResultSet) + + rows = cursor.result_set.fetchall() + assert len(rows) == 3 + + col_names = [desc[0] for desc in cursor.result_set.description] + assert "date" in col_names + date_idx = col_names.index("date") + + start = Timestamp(int(datetime(2025, 1, 1, tzinfo=timezone.utc).timestamp()), 1) + end = Timestamp(int(datetime(2026, 1, 1, tzinfo=timezone.utc).timestamp()), 1) + + for row in rows: + date_value = row[date_idx] + assert isinstance(date_value, Timestamp) + assert date_value > start + assert date_value < end diff --git a/tests/test_value_function_registry.py b/tests/test_value_function_registry.py new file mode 100644 index 0000000..84278d7 --- /dev/null +++ b/tests/test_value_function_registry.py @@ -0,0 +1,274 @@ +# -*- coding: utf-8 -*- +from datetime import datetime, timezone + +import pytest +from bson.timestamp import Timestamp + +from pymongosql.sql.value_function_registry import ( + ValueFunctionExecutionError, + ValueFunctionRegistry, +) + + +@pytest.fixture +def registry(): + """Fixture providing a fresh ValueFunctionRegistry for each test""" + return ValueFunctionRegistry() + + +class TestValueFunctionRegistry: + """Test cases for ValueFunctionRegistry""" + + def test_registry_initialization(self, registry): + """Test that registry initializes with built-in functions""" + assert registry.has_function("str_to_datetime") + assert registry.has_function("str_to_timestamp") + + def test_list_functions(self, registry): + """Test listing registered functions""" + functions = registry.list_functions() + assert "str_to_datetime" in functions + assert "str_to_timestamp" in functions + + def test_function_case_insensitive(self, registry): + """Test that function names are case-insensitive""" + assert registry.has_function("STR_TO_DATETIME") + assert registry.has_function("Str_To_Datetime") + assert registry.has_function("STR_TO_TIMESTAMP") + + +class TestDatetimeFunction: + """Test cases for str_to_datetime() function""" + + def test_datetime_iso8601_basic(self, registry): + """Test datetime conversion from ISO 8601 format""" + result = registry.execute("str_to_datetime", ["2024-01-15"]) + assert isinstance(result, datetime) + assert result.year == 2024 + assert result.month == 1 + assert result.day == 15 + assert result.tzinfo == timezone.utc + + def test_datetime_iso8601_with_time(self, registry): + """Test datetime conversion from ISO 8601 with time""" + result = registry.execute("str_to_datetime", ["2024-01-15T10:30:45"]) + assert isinstance(result, datetime) + assert result.year == 2024 + assert result.month == 1 + assert result.day == 15 + assert result.hour == 10 + assert result.minute == 30 + assert result.second == 45 + assert result.tzinfo == timezone.utc + + def test_datetime_iso8601_with_z(self, registry): + """Test datetime conversion from ISO 8601 with Z timezone""" + result = registry.execute("str_to_datetime", ["2024-01-15T10:30:45Z"]) + assert isinstance(result, datetime) + assert result.year == 2024 + assert result.month == 1 + assert result.day == 15 + assert result.hour == 10 + assert result.minute == 30 + assert result.second == 45 + + def test_datetime_custom_format(self, registry): + """Test datetime conversion with custom format""" + result = registry.execute("str_to_datetime", ["01/15/2024", "%m/%d/%Y"]) + assert isinstance(result, datetime) + assert result.year == 2024 + assert result.month == 1 + assert result.day == 15 + assert result.tzinfo == timezone.utc + + def test_datetime_custom_format_with_time(self, registry): + """Test datetime conversion with custom format including time""" + result = registry.execute("str_to_datetime", ["01/15/2024 10:30:45", "%m/%d/%Y %H:%M:%S"]) + assert isinstance(result, datetime) + assert result.year == 2024 + assert result.month == 1 + assert result.day == 15 + assert result.hour == 10 + assert result.minute == 30 + assert result.second == 45 + + def test_datetime_invalid_format(self, registry): + """Test datetime with invalid format raises error""" + with pytest.raises(ValueFunctionExecutionError): + registry.execute("str_to_datetime", ["invalid-date"]) + + def test_datetime_invalid_format_string(self, registry): + """Test datetime with invalid format string raises error""" + with pytest.raises(ValueFunctionExecutionError): + registry.execute("str_to_datetime", ["01/15/2024", "%Y-%m-%d"]) # Format mismatch + + def test_datetime_missing_argument(self, registry): + """Test datetime with missing argument raises error""" + with pytest.raises(ValueFunctionExecutionError): + registry.execute("str_to_datetime", []) + + def test_datetime_too_many_arguments(self, registry): + """Test datetime with too many arguments raises error""" + with pytest.raises(ValueFunctionExecutionError): + registry.execute("str_to_datetime", ["2024-01-15", "%Y-%m-%d", "extra"]) + + def test_datetime_non_string_value(self, registry): + """Test datetime with non-string value raises error""" + with pytest.raises(ValueFunctionExecutionError): + registry.execute("str_to_datetime", [12345]) + + def test_datetime_whitespace_handling(self, registry): + """Test datetime handles whitespace in input""" + result = registry.execute("str_to_datetime", [" 2024-01-15 "]) + assert isinstance(result, datetime) + assert result.year == 2024 + assert result.month == 1 + assert result.day == 15 + + +class TestTimestampFunction: + """Test cases for str_to_timestamp() function""" + + def test_timestamp_iso8601_basic(self, registry): + """Test timestamp conversion from ISO 8601 format""" + result = registry.execute("str_to_timestamp", ["2024-01-15"]) + assert isinstance(result, Timestamp) + assert result.time is not None + assert result.inc is not None + + def test_timestamp_iso8601_with_time(self, registry): + """Test timestamp conversion from ISO 8601 with time""" + result = registry.execute("str_to_timestamp", ["2024-01-15T10:30:45"]) + assert isinstance(result, Timestamp) + assert result.time is not None + + def test_timestamp_iso8601_with_z(self, registry): + """Test timestamp conversion from ISO 8601 with Z timezone""" + result = registry.execute("str_to_timestamp", ["2024-01-15T10:30:45Z"]) + assert isinstance(result, Timestamp) + assert result.time is not None + + def test_timestamp_custom_format(self, registry): + """Test timestamp conversion with custom format""" + result = registry.execute("str_to_timestamp", ["01/15/2024", "%m/%d/%Y"]) + assert isinstance(result, Timestamp) + assert result.time is not None + + def test_timestamp_custom_format_with_time(self, registry): + """Test timestamp conversion with custom format including time""" + result = registry.execute("str_to_timestamp", ["01/15/2024 10:30:45", "%m/%d/%Y %H:%M:%S"]) + assert isinstance(result, Timestamp) + assert result.time is not None + + def test_timestamp_increment_value(self, registry): + """Test timestamp has increment value of 1""" + result = registry.execute("str_to_timestamp", ["2024-01-15"]) + assert result.inc == 1 + + def test_timestamp_invalid_format(self, registry): + """Test timestamp with invalid format raises error""" + with pytest.raises(ValueFunctionExecutionError): + registry.execute("str_to_timestamp", ["invalid-date"]) + + def test_timestamp_missing_argument(self, registry): + """Test timestamp with missing argument raises error""" + with pytest.raises(ValueFunctionExecutionError): + registry.execute("str_to_timestamp", []) + + def test_timestamp_too_many_arguments(self, registry): + """Test timestamp with too many arguments raises error""" + with pytest.raises(ValueFunctionExecutionError): + registry.execute("str_to_timestamp", ["2024-01-15", "%Y-%m-%d", "extra"]) + + def test_timestamp_non_string_value(self, registry): + """Test timestamp with non-string value raises error""" + with pytest.raises(ValueFunctionExecutionError): + registry.execute("str_to_timestamp", [12345]) + + +class TestCustomFunctionRegistration: + """Test cases for registering custom functions""" + + def test_register_custom_function(self, registry): + """Test registering a custom function""" + + def custom_upper(val): + return val.upper() if isinstance(val, str) else str(val).upper() + + registry.register("upper", custom_upper) + assert registry.has_function("upper") + + def test_execute_custom_function(self, registry): + """Test executing a custom function""" + + def custom_upper(val): + return val.upper() if isinstance(val, str) else str(val).upper() + + registry.register("upper", custom_upper) + result = registry.execute("upper", ["hello"]) + assert result == "HELLO" + + def test_register_invalid_function_name(self, registry): + """Test registering with invalid function name raises error""" + + def dummy(): + pass + + with pytest.raises(ValueError): + registry.register("", dummy) + + with pytest.raises(ValueError): + registry.register(None, dummy) + + def test_register_non_callable(self, registry): + """Test registering non-callable raises error""" + with pytest.raises(ValueError): + registry.register("notfunc", "not a function") + + def test_unregister_function(self, registry): + """Test unregistering a function""" + + def dummy(): + pass + + registry.register("temp", dummy) + assert registry.has_function("temp") + registry.unregister("temp") + assert not registry.has_function("temp") + + def test_unregister_nonexistent_function(self, registry): + """Test unregistering non-existent function doesn't raise error""" + # Should not raise + registry.unregister("nonexistent") + + def test_overwrite_existing_function(self, registry): + """Test overwriting an existing function""" + + def func1(): + return 1 + + def func2(): + return 2 + + registry.register("test", func1) + result1 = registry.execute("test", []) + assert result1 == 1 + + registry.register("test", func2) + result2 = registry.execute("test", []) + assert result2 == 2 + + +class TestFunctionExecutionErrors: + """Test error handling in function execution""" + + def test_nonexistent_function(self, registry): + """Test executing non-existent function raises error""" + with pytest.raises(ValueFunctionExecutionError) as exc_info: + registry.execute("nonexistent", []) + assert "nonexistent" in str(exc_info.value) + + def test_function_with_wrong_argument_count(self, registry): + """Test executing function with wrong argument count raises error""" + with pytest.raises(ValueFunctionExecutionError): + registry.execute("str_to_datetime", ["2024-01-15", "extra", "args"])