From b96cb38e4c387a2d075dfdf6ecfac3edb12697d0 Mon Sep 17 00:00:00 2001 From: Francesco Faraone Date: Tue, 1 Apr 2025 18:17:31 +0200 Subject: [PATCH] cast value to column datatype if needed --- src/requela/builders/sqlalchemy.py | 42 +++++++++++++++++++++++------ tests/sqlalchemy/test_comparison.py | 2 ++ 2 files changed, 36 insertions(+), 8 deletions(-) diff --git a/src/requela/builders/sqlalchemy.py b/src/requela/builders/sqlalchemy.py index 7cd22c9..4750599 100644 --- a/src/requela/builders/sqlalchemy.py +++ b/src/requela/builders/sqlalchemy.py @@ -1,5 +1,7 @@ from collections.abc import Callable, Sequence from datetime import date, datetime +from decimal import Decimal +from typing import Any from sqlalchemy import ( BooleanClauseList, @@ -61,7 +63,7 @@ def apply_eq( if value is True or value is False or value is None: return model_field.is_(value) - return model_field == value + return model_field == self.cast_value(model_field, value) def apply_eq_to_relationship( self, relationship_property: RelationshipProperty @@ -81,7 +83,7 @@ def apply_ne( return self.apply_ne_to_relationship(model_field.property) if value is True or value is False or value is None: return model_field.isnot(value) - return model_field != value + return model_field != self.cast_value(model_field, value) def apply_ne_to_relationship( self, relationship_property: RelationshipProperty @@ -92,30 +94,34 @@ def apply_ne_to_relationship( return or_(*conditions) def apply_gt(self, prop: str, value: date | datetime | int | float) -> ColumnExpressionArgument: - return self.resolve_property(prop) > value + return self.resolve_property(prop) > self.cast_value(self.resolve_property(prop), value) def apply_lt(self, prop: str, value: date | datetime | int | float) -> ColumnExpressionArgument: - return self.resolve_property(prop) < value + return self.resolve_property(prop) < self.cast_value(self.resolve_property(prop), value) def apply_gte( self, prop: str, value: date | datetime | int | float ) -> ColumnExpressionArgument: - return self.resolve_property(prop) >= value + return self.resolve_property(prop) >= self.cast_value(self.resolve_property(prop), value) def apply_lte( self, prop: str, value: date | datetime | int | float ) -> ColumnExpressionArgument: - return self.resolve_property(prop) <= value + return self.resolve_property(prop) <= self.cast_value(self.resolve_property(prop), value) def apply_in( self, prop: str, value: Sequence[str] | Sequence[float] | Sequence[int] ) -> ColumnExpressionArgument: - return self.resolve_property(prop).in_(value) + return self.resolve_property(prop).in_( + [self.cast_value(self.resolve_property(prop), item) for item in value] + ) def apply_out( self, prop: str, value: Sequence[str] | Sequence[float] | Sequence[int] ) -> ColumnExpressionArgument: - return self.resolve_property(prop).not_in(value) + return self.resolve_property(prop).not_in( + [self.cast_value(self.resolve_property(prop), item) for item in value] + ) def apply_like(self, prop: str, value: str) -> ColumnExpressionArgument: sql_pattern = value.replace("*", "%") @@ -167,6 +173,26 @@ def apply_any( ) return exists_clause + def cast_value(self, column: ColumnElement, value: Any) -> Any: # pragma: no cover + try: + if column.type.python_type is str and not isinstance(value, str): + return str(value) + if column.type.python_type is bool and not isinstance(value, bool): + return bool(value) + if column.type.python_type is int and not isinstance(value, int): + return int(value) + if column.type.python_type is float and not isinstance(value, float): + return float(value) + if column.type.python_type is Decimal and not isinstance(value, Decimal): + return Decimal(value) + if column.type.python_type is datetime and not isinstance(value, datetime): + return datetime.fromisoformat(value) + if column.type.python_type is date and not isinstance(value, date): + return date.fromisoformat(value) + return value + except Exception: + raise ValueError(f"Cannot cast value {value} to {column.type.python_type}") + def _adapt_condition(self, condition, alias): if isinstance(condition, BooleanClauseList): return condition.operator( diff --git a/tests/sqlalchemy/test_comparison.py b/tests/sqlalchemy/test_comparison.py index 1cd4ab4..4f1b98a 100644 --- a/tests/sqlalchemy/test_comparison.py +++ b/tests/sqlalchemy/test_comparison.py @@ -14,6 +14,7 @@ [ (User, "name", "Ratatouille 123", select(User).filter(User.name == "Ratatouille 123")), (User, "age", 25, select(User).filter(User.age == 25)), + (Account, "name", 25, select(Account).filter(Account.name == "25")), (Account, "balance", 25.13, select(Account).filter(Account.balance == 25.13)), (Account, "balance", -71.14, select(Account).filter(Account.balance == -71.14)), (Account, "balance", 33, select(Account).filter(Account.balance == 33)), @@ -50,6 +51,7 @@ def test_comparison_eq(model, field, value, expected): [ (User, "name", "Ratatouille", select(User).filter(User.name != "Ratatouille")), (User, "age", 25, select(User).filter(User.age != 25)), + (Account, "name", 25, select(Account).filter(Account.name != "25")), (Account, "balance", 25.13, select(Account).filter(Account.balance != 25.13)), (Account, "balance", -71.14, select(Account).filter(Account.balance != -71.14)), (Account, "balance", 33, select(Account).filter(Account.balance != 33)),