diff --git a/pyproject.toml b/pyproject.toml index 37d436f657..51e6159129 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,7 @@ classifiers = [ [tool.poetry.dependencies] python = "^3.7" SQLAlchemy = ">=2.0.0,<=2.0.11" -pydantic = "^2.1.1" +pydantic = { version = ">=2.1.1,<=2.4", extras = ["email"] } [tool.poetry.dev-dependencies] pytest = "^7.0.1" diff --git a/sqlmodel/main.py b/sqlmodel/main.py index dbc05c48c4..0f916ace99 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -29,7 +29,8 @@ ) import pydantic -from pydantic import BaseModel +from annotated_types import MaxLen +from pydantic import BaseModel, EmailStr, ImportString, NameEmail from pydantic._internal._fields import PydanticGeneralMetadata from pydantic._internal._model_construction import ModelMetaclass from pydantic._internal._repr import Representation @@ -480,7 +481,7 @@ def get_sqlalchemy_type(field: FieldInfo) -> Any: metadata = _get_field_metadata(field) if type_ is None: raise ValueError("Missing field type") - if issubclass(type_, str): + if issubclass(type_, str) or type_ in (EmailStr, NameEmail, ImportString): max_length = getattr(metadata, "max_length", None) if max_length: return AutoString(length=max_length) @@ -693,4 +694,6 @@ def _get_field_metadata(field: FieldInfo) -> object: for meta in field.metadata: if isinstance(meta, PydanticGeneralMetadata): return meta + if isinstance(meta, MaxLen): + return meta return object() diff --git a/tests/test_pydantic_types.py b/tests/test_pydantic_types.py new file mode 100644 index 0000000000..7d2a17b3b5 --- /dev/null +++ b/tests/test_pydantic_types.py @@ -0,0 +1,24 @@ +from pydantic import EmailStr, HttpUrl, ImportString, NameEmail +from sqlmodel import Field, SQLModel, create_engine + + +def test_pydantic_types(clear_sqlmodel, caplog): + class Hero(SQLModel, table=True): + integer_primary_key: int = Field( + primary_key=True, + ) + http: HttpUrl = Field(max_length=250) + email: EmailStr + name_email: NameEmail = Field(max_length=50) + import_string: ImportString = Field(max_length=200, min_length=100) + + engine = create_engine("sqlite://", echo=True) + SQLModel.metadata.create_all(engine) + + create_table_log = [ + message for message in caplog.messages if "CREATE TABLE hero" in message + ][0] + assert "http VARCHAR(250) NOT NULL," in create_table_log + assert "email VARCHAR NOT NULL," in create_table_log + assert "name_email VARCHAR(50) NOT NULL," in create_table_log + assert "import_string VARCHAR(200) NOT NULL," in create_table_log