diff --git a/pyproject.toml b/pyproject.toml index c7956daaa9..181064e4ea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ classifiers = [ [tool.poetry.dependencies] python = "^3.7" SQLAlchemy = ">=1.4.36,<2.0.0" -pydantic = "^1.8.2" +pydantic = "^1.9.0" sqlalchemy2-stubs = {version = "*", allow-prereleases = true} [tool.poetry.group.dev.dependencies] diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 07e600e4d4..3015aa9fbd 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -145,12 +145,17 @@ def Field( lt: Optional[float] = None, le: Optional[float] = None, multiple_of: Optional[float] = None, + max_digits: Optional[int] = None, + decimal_places: Optional[int] = None, min_items: Optional[int] = None, max_items: Optional[int] = None, + unique_items: Optional[bool] = None, min_length: Optional[int] = None, max_length: Optional[int] = None, allow_mutation: bool = True, regex: Optional[str] = None, + discriminator: Optional[str] = None, + repr: bool = True, primary_key: bool = False, foreign_key: Optional[Any] = None, unique: bool = False, @@ -176,12 +181,17 @@ def Field( lt=lt, le=le, multiple_of=multiple_of, + max_digits=max_digits, + decimal_places=decimal_places, min_items=min_items, max_items=max_items, + unique_items=unique_items, min_length=min_length, max_length=max_length, allow_mutation=allow_mutation, regex=regex, + discriminator=discriminator, + repr=repr, primary_key=primary_key, foreign_key=foreign_key, unique=unique, @@ -587,7 +597,11 @@ def parse_obj( def __repr_args__(self) -> Sequence[Tuple[Optional[str], Any]]: # Don't show SQLAlchemy private attributes - return [(k, v) for k, v in self.__dict__.items() if not k.startswith("_sa_")] + return [ + (k, v) + for k, v in super().__repr_args__() + if not (isinstance(k, str) and k.startswith("_sa_")) + ] # From Pydantic, override to enforce validation with dict @classmethod diff --git a/tests/test_pydantic/__init__.py b/tests/test_pydantic/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_pydantic/test_field.py b/tests/test_pydantic/test_field.py new file mode 100644 index 0000000000..9d7bc77625 --- /dev/null +++ b/tests/test_pydantic/test_field.py @@ -0,0 +1,57 @@ +from decimal import Decimal +from typing import Optional, Union + +import pytest +from pydantic import ValidationError +from sqlmodel import Field, SQLModel +from typing_extensions import Literal + + +def test_decimal(): + class Model(SQLModel): + dec: Decimal = Field(max_digits=4, decimal_places=2) + + Model(dec=Decimal("3.14")) + Model(dec=Decimal("69.42")) + + with pytest.raises(ValidationError): + Model(dec=Decimal("3.142")) + with pytest.raises(ValidationError): + Model(dec=Decimal("0.069")) + with pytest.raises(ValidationError): + Model(dec=Decimal("420")) + + +def test_discriminator(): + # Example adapted from + # [Pydantic docs](https://pydantic-docs.helpmanual.io/usage/types/#discriminated-unions-aka-tagged-unions): + + class Cat(SQLModel): + pet_type: Literal["cat"] + meows: int + + class Dog(SQLModel): + pet_type: Literal["dog"] + barks: float + + class Lizard(SQLModel): + pet_type: Literal["reptile", "lizard"] + scales: bool + + class Model(SQLModel): + pet: Union[Cat, Dog, Lizard] = Field(..., discriminator="pet_type") + n: int + + Model(pet={"pet_type": "dog", "barks": 3.14}, n=1) # type: ignore[arg-type] + + with pytest.raises(ValidationError): + Model(pet={"pet_type": "dog"}, n=1) # type: ignore[arg-type] + + +def test_repr(): + class Model(SQLModel): + id: Optional[int] = Field(primary_key=True) + foo: str = Field(repr=False) + + instance = Model(id=123, foo="bar") + assert "foo=" not in repr(instance)