diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 7c916f79af..3b79893f3b 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -54,7 +54,7 @@ from sqlalchemy.orm.instrumentation import is_instrumented from sqlalchemy.sql.schema import MetaData from sqlalchemy.sql.sqltypes import LargeBinary, Time, Uuid -from typing_extensions import Literal, TypeAlias, deprecated, get_origin +from typing_extensions import Annotated, Literal, TypeAlias, deprecated, get_args, get_origin from ._compat import ( # type: ignore[attr-defined] IS_PYDANTIC_V2, @@ -562,7 +562,8 @@ def get_config(name: str) -> Any: # If it was passed by kwargs, ensure it's also set in config set_config_value(model=new_cls, parameter="table", value=config_table) for k, v in get_model_fields(new_cls).items(): - col = get_column_from_field(v) + original_annotation = new_cls.__annotations__.get(k) + col = get_column_from_field(v, original_annotation) setattr(new_cls, k, col) # Set a config flag to tell FastAPI that this should be read with a field # in orm_mode instead of preemptively converting it to a dict. @@ -646,12 +647,44 @@ def __init__( ModelMetaclass.__init__(cls, classname, bases, dict_, **kw) -def get_sqlalchemy_type(field: Any) -> Any: +def _get_sqlmodel_field_info_from_annotation(annotation: Any) -> Optional["FieldInfo"]: + """Extract SQLModel FieldInfo from an Annotated type's metadata. + + When using Annotated[type, Field(...), Validator(...)], Pydantic V2 may create + a new pydantic.fields.FieldInfo that doesn't preserve SQLModel-specific attributes + like sa_column and sa_type. This function looks through the Annotated metadata + to find the original SQLModel FieldInfo. + """ + if get_origin(annotation) is not Annotated: + return None + for arg in get_args(annotation)[1:]: # Skip the first arg (the actual type) + if isinstance(arg, FieldInfo): + return arg + return None + + +def get_sqlalchemy_type(field: Any, original_annotation: Any = None) -> Any: if IS_PYDANTIC_V2: field_info = field else: field_info = field.field_info sa_type = getattr(field_info, "sa_type", Undefined) # noqa: B009 + # If sa_type not found on field_info, check if it's in the Annotated metadata + # This handles the case where Pydantic V2 creates a new FieldInfo losing SQLModel attrs + if sa_type is Undefined and IS_PYDANTIC_V2: + # First try field_info.annotation (may be unpacked by Pydantic) + annotation = getattr(field_info, "annotation", None) + if annotation is not None: + sqlmodel_field_info = _get_sqlmodel_field_info_from_annotation(annotation) + if sqlmodel_field_info is not None: + sa_type = getattr(sqlmodel_field_info, "sa_type", Undefined) + # If still not found, try the original annotation from the class + if sa_type is Undefined and original_annotation is not None: + sqlmodel_field_info = _get_sqlmodel_field_info_from_annotation( + original_annotation + ) + if sqlmodel_field_info is not None: + sa_type = getattr(sqlmodel_field_info, "sa_type", Undefined) if sa_type is not Undefined: return sa_type @@ -703,15 +736,33 @@ def get_sqlalchemy_type(field: Any) -> Any: raise ValueError(f"{type_} has no matching SQLAlchemy type") -def get_column_from_field(field: Any) -> Column: # type: ignore +def get_column_from_field( + field: Any, original_annotation: Any = None +) -> Column: # type: ignore if IS_PYDANTIC_V2: field_info = field else: field_info = field.field_info sa_column = getattr(field_info, "sa_column", Undefined) + # If sa_column not found on field_info, check if it's in the Annotated metadata + # This handles the case where Pydantic V2 creates a new FieldInfo losing SQLModel attrs + if sa_column is Undefined and IS_PYDANTIC_V2: + # First try field_info.annotation (may be unpacked by Pydantic) + annotation = getattr(field_info, "annotation", None) + if annotation is not None: + sqlmodel_field_info = _get_sqlmodel_field_info_from_annotation(annotation) + if sqlmodel_field_info is not None: + sa_column = getattr(sqlmodel_field_info, "sa_column", Undefined) + # If still not found, try the original annotation from the class + if sa_column is Undefined and original_annotation is not None: + sqlmodel_field_info = _get_sqlmodel_field_info_from_annotation( + original_annotation + ) + if sqlmodel_field_info is not None: + sa_column = getattr(sqlmodel_field_info, "sa_column", Undefined) if isinstance(sa_column, Column): return sa_column - sa_type = get_sqlalchemy_type(field) + sa_type = get_sqlalchemy_type(field, original_annotation) primary_key = getattr(field_info, "primary_key", Undefined) if primary_key is Undefined: primary_key = False diff --git a/tests/test_annotated_sa_column.py b/tests/test_annotated_sa_column.py new file mode 100644 index 0000000000..2ac26e8a09 --- /dev/null +++ b/tests/test_annotated_sa_column.py @@ -0,0 +1,94 @@ +"""Tests for Annotated fields with sa_column and Pydantic validators. + +When using Annotated[type, Field(sa_column=...), Validator(...)], Pydantic V2 may +create a new FieldInfo that doesn't preserve SQLModel-specific attributes like +sa_column. These tests ensure the sa_column is properly extracted from the +Annotated metadata. +""" + +from datetime import datetime +from typing import Annotated, Optional + +from pydantic import AfterValidator, BeforeValidator +from sqlalchemy import Column, DateTime, String +from sqlmodel import Field, SQLModel + + +def test_annotated_sa_column_with_validators() -> None: + """Test that sa_column is preserved when using Annotated with validators.""" + + def before_validate(v: datetime) -> datetime: + return v + + def after_validate(v: datetime) -> datetime: + return v + + class Position(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + timestamp: Annotated[ + datetime, + Field( + sa_column=Column( + DateTime(timezone=True), nullable=False, index=True + ) + ), + BeforeValidator(before_validate), + AfterValidator(after_validate), + ] + + # Verify the column type has timezone=True + assert Position.__table__.c.timestamp.type.timezone is True + assert Position.__table__.c.timestamp.nullable is False + assert Position.__table__.c.timestamp.index is True + + +def test_annotated_sa_column_with_single_validator() -> None: + """Test sa_column with just one validator.""" + + def validate_name(v: str) -> str: + return v.strip() + + class Item(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: Annotated[ + str, + Field(sa_column=Column(String(100), nullable=False, unique=True)), + AfterValidator(validate_name), + ] + + assert isinstance(Item.__table__.c.name.type, String) + assert Item.__table__.c.name.type.length == 100 + assert Item.__table__.c.name.nullable is False + assert Item.__table__.c.name.unique is True + + +def test_annotated_sa_column_without_validators() -> None: + """Test that sa_column still works with Annotated but no validators.""" + + class Record(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + created_at: Annotated[ + datetime, + Field(sa_column=Column(DateTime(timezone=True), nullable=False)), + ] + + assert Record.__table__.c.created_at.type.timezone is True + assert Record.__table__.c.created_at.nullable is False + + +def test_annotated_sa_type_with_validators() -> None: + """Test that sa_type is preserved when using Annotated with validators.""" + + def validate_timestamp(v: datetime) -> datetime: + return v + + class Event(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + occurred_at: Annotated[ + datetime, + Field(sa_type=DateTime(timezone=True)), + AfterValidator(validate_timestamp), + ] + + # Verify the column type has timezone=True + assert Event.__table__.c.occurred_at.type.timezone is True