Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 34 additions & 8 deletions src/requela/builders/sqlalchemy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from collections.abc import Callable, Sequence
from datetime import date, datetime
from decimal import Decimal
from typing import Any

from sqlalchemy import (
BooleanClauseList,
Expand Down Expand Up @@ -61,7 +63,7 @@ def apply_eq(

if value is True or value is False or value is None:
return model_field.is_(value)
return model_field == value
return model_field == self.cast_value(model_field, value)

def apply_eq_to_relationship(
self, relationship_property: RelationshipProperty
Expand All @@ -81,7 +83,7 @@ def apply_ne(
return self.apply_ne_to_relationship(model_field.property)
if value is True or value is False or value is None:
return model_field.isnot(value)
return model_field != value
return model_field != self.cast_value(model_field, value)

def apply_ne_to_relationship(
self, relationship_property: RelationshipProperty
Expand All @@ -92,30 +94,34 @@ def apply_ne_to_relationship(
return or_(*conditions)

def apply_gt(self, prop: str, value: date | datetime | int | float) -> ColumnExpressionArgument:
return self.resolve_property(prop) > value
return self.resolve_property(prop) > self.cast_value(self.resolve_property(prop), value)

def apply_lt(self, prop: str, value: date | datetime | int | float) -> ColumnExpressionArgument:
return self.resolve_property(prop) < value
return self.resolve_property(prop) < self.cast_value(self.resolve_property(prop), value)

def apply_gte(
self, prop: str, value: date | datetime | int | float
) -> ColumnExpressionArgument:
return self.resolve_property(prop) >= value
return self.resolve_property(prop) >= self.cast_value(self.resolve_property(prop), value)

def apply_lte(
self, prop: str, value: date | datetime | int | float
) -> ColumnExpressionArgument:
return self.resolve_property(prop) <= value
return self.resolve_property(prop) <= self.cast_value(self.resolve_property(prop), value)

def apply_in(
self, prop: str, value: Sequence[str] | Sequence[float] | Sequence[int]
) -> ColumnExpressionArgument:
return self.resolve_property(prop).in_(value)
return self.resolve_property(prop).in_(
[self.cast_value(self.resolve_property(prop), item) for item in value]
)

def apply_out(
self, prop: str, value: Sequence[str] | Sequence[float] | Sequence[int]
) -> ColumnExpressionArgument:
return self.resolve_property(prop).not_in(value)
return self.resolve_property(prop).not_in(
[self.cast_value(self.resolve_property(prop), item) for item in value]
)

def apply_like(self, prop: str, value: str) -> ColumnExpressionArgument:
sql_pattern = value.replace("*", "%")
Expand Down Expand Up @@ -167,6 +173,26 @@ def apply_any(
)
return exists_clause

def cast_value(self, column: ColumnElement, value: Any) -> Any: # pragma: no cover
try:
if column.type.python_type is str and not isinstance(value, str):
return str(value)
if column.type.python_type is bool and not isinstance(value, bool):
return bool(value)
if column.type.python_type is int and not isinstance(value, int):
return int(value)
if column.type.python_type is float and not isinstance(value, float):
return float(value)
if column.type.python_type is Decimal and not isinstance(value, Decimal):
return Decimal(value)
if column.type.python_type is datetime and not isinstance(value, datetime):
return datetime.fromisoformat(value)
if column.type.python_type is date and not isinstance(value, date):
return date.fromisoformat(value)
return value
except Exception:
raise ValueError(f"Cannot cast value {value} to {column.type.python_type}")

def _adapt_condition(self, condition, alias):
if isinstance(condition, BooleanClauseList):
return condition.operator(
Expand Down
2 changes: 2 additions & 0 deletions tests/sqlalchemy/test_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
[
(User, "name", "Ratatouille 123", select(User).filter(User.name == "Ratatouille 123")),
(User, "age", 25, select(User).filter(User.age == 25)),
(Account, "name", 25, select(Account).filter(Account.name == "25")),
(Account, "balance", 25.13, select(Account).filter(Account.balance == 25.13)),
(Account, "balance", -71.14, select(Account).filter(Account.balance == -71.14)),
(Account, "balance", 33, select(Account).filter(Account.balance == 33)),
Expand Down Expand Up @@ -50,6 +51,7 @@ def test_comparison_eq(model, field, value, expected):
[
(User, "name", "Ratatouille", select(User).filter(User.name != "Ratatouille")),
(User, "age", 25, select(User).filter(User.age != 25)),
(Account, "name", 25, select(Account).filter(Account.name != "25")),
(Account, "balance", 25.13, select(Account).filter(Account.balance != 25.13)),
(Account, "balance", -71.14, select(Account).filter(Account.balance != -71.14)),
(Account, "balance", 33, select(Account).filter(Account.balance != 33)),
Expand Down
Loading