JSON Fields for Nested Pydantic Models? #63

Open
8 tasks done
scuervo91 opened this issue Aug 31, 2021 · 64 comments
Labels
question Further information is requested

Comments

@scuervo91

First Check

  • I added a very descriptive title to this issue.
  • I used the GitHub search to find a similar issue and didn't find it.
  • I searched the SQLModel documentation, with the integrated search.
  • I already searched in Google "How to X in SQLModel" and didn't find any information.
  • I already read and followed all the tutorial in the docs and didn't find an answer.
  • I already checked if it is not related to SQLModel but to Pydantic.
  • I already checked if it is not related to SQLModel but to SQLAlchemy.

Commit to Help

  • I commit to help with one of those options 👆

Example Code

from tortoise.models import Model
from tortoise.fields import UUIDField, DatetimeField, CharField, BooleanField, JSONField, ForeignKeyField, CharEnumField, IntField
from tortoise.contrib.pydantic import pydantic_model_creator

# SchemasEnum (the enum of schedule types) is defined elsewhere in the project

class Schedule(Model):
    id = UUIDField(pk=True)
    created_at = DatetimeField(auto_now_add=True)
    modified_at = DatetimeField(auto_now=True)
    case = JSONField()
    type = CharEnumField(SchemasEnum,description='Schedule Types')
    username = ForeignKeyField('models.Username')
    description = CharField(100)
    
schedule_pydantic = pydantic_model_creator(Schedule,name='Schedule')

Description

I have already implemented an API using FastAPI to store Pydantic models. These models are themselves nested Pydantic models, so they interact with a Postgres database through a JSONField. I've been using Tortoise ORM, as the example shows.

Is there an equivalent model in SQLModel?

Operating System

Linux

Operating System Details

WSL 2 Ubuntu 20.04

SQLModel Version

0.0.4

Python Version

3.8

Additional Context

No response

@scuervo91 scuervo91 added the question Further information is requested label Aug 31, 2021
@OXERY

OXERY commented Sep 3, 2021

I also wondered how to store JSON objects without converting them to strings. SQLAlchemy supports storing these directly.

@TheJedinator

@OXERY && @scuervo91 - I was able to get something working using this:

regions: dict = Field(sa_column=Column(JSON), default={'all': 'true'})

Note: in my database this is a PostgreSQL JSONB column, but it works.

For a nested object, you could use a Pydantic model as the type and do it the same way. Hope this helps, as I was having a difficult time finding a solution as well :)

@OXERY

OXERY commented Sep 10, 2021

I also got it working on SQLite and PostgreSQL:

from typing import Any, Dict

from sqlmodel import JSON, Column, Field, SQLModel

mygreatfield: Dict[Any, Any] = Field(index=False, sa_column=Column(JSON))

@psarka

psarka commented Dec 1, 2021

@TheJedinator Could you help a bit more with the nested object? I tried to "use the pydantic model as the Type" but I can't get it to work :( Here is my snippet:

from sqlalchemy import Column
from sqlalchemy.dialects.postgresql import JSONB
from sqlmodel import Field
from sqlmodel import Session
from sqlmodel import SQLModel

from engine import get_sqlalchemy_engine


class J(SQLModel):
    j: int


class A(SQLModel, table=True):
    a: int = Field(primary_key=True)
    b: J = Field(sa_column=Column(JSONB))


engine = get_sqlalchemy_engine()
SQLModel.metadata.create_all(engine)

with Session(engine) as session:
    a = A(a=1, b=J(j=1))
    session.add(a)
    session.commit()
    session.refresh(a)

Throws an error

sqlalchemy.exc.StatementError: (builtins.TypeError) Object of type J is not JSON serializable
[SQL: INSERT INTO a (b, a) VALUES (%(b)s, %(a)s)]
[parameters: [{'a': 1, 'b': J(j=1)}]]

@TheJedinator

@psarka

j = J(j=1)
db_j = J.from_orm(j)
a = A(a=1, b=db_j)

This should resolve your issue in preparing the object for the database. From the error, it looks like the raw object is being included in the statement rather than the instance...

If this doesn't help, I can definitely put some more time into looking at what's going on.

@psarka

psarka commented Dec 1, 2021

Thank you! Unfortunately I get the same error :(

I found one workaround - registering a custom_serializer for the sqlalchemy engine, like so:

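import json

from sqlalchemy import create_engine

def custom_serializer(d):
    # .dict() keeps the nested model a JSON object; returning .json() here
    # would store it as a JSON-encoded string instead
    return json.dumps(d, default=lambda v: v.dict())

def get_sqlalchemy_engine():
    # get_conn: connection factory defined elsewhere (not shown)
    return create_engine("postgresql+psycopg2://", creator=get_conn, json_serializer=custom_serializer)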

But if there is a cleaner way, I would gladly use that instead.
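Why a serializer hook like this works, and one pitfall: `json.dumps` calls `default` for any object it cannot encode, then encodes whatever `default` returns. Returning a dict embeds the nested model as a JSON object; returning the output of `.json()` (already a string) embeds it as a quoted string, so the column ends up holding double-encoded JSON. A stdlib-only illustration, with `Point` as a hypothetical stand-in for a Pydantic v1 model:

```python
import json


class Point:
    """Hypothetical stand-in for a Pydantic v1 model with .dict()/.json()."""

    def __init__(self, x: int, y: int) -> None:
        self.x, self.y = x, y

    def dict(self) -> dict:
        return {"x": self.x, "y": self.y}

    def json(self) -> str:
        return json.dumps(self.dict())


row = {"a": 1, "b": Point(1, 2)}

# default returns a dict: the nested object is encoded as a JSON object
as_object = json.dumps(row, default=lambda v: v.dict())
# default returns a str: the nested object is encoded as a JSON *string*
as_string = json.dumps(row, default=lambda v: v.json())

print(as_object)  # {"a": 1, "b": {"x": 1, "y": 2}}
print(as_string)  # {"a": 1, "b": "{\"x\": 1, \"y\": 2}"}
```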

@TheJedinator

Hey @psarka

I actually just tried what I suggested, and I'm sorry for misleading you... I did get a working solution though 😄

It was actually the opposite function you need to use. Here's your example with the amendments needed to make it work:

with Session(engine) as session:
    j = J(j=1)
    j_dumped = J.json(j)
    a = A(a=1, b=j_dumped)
    session.add(a)
    session.commit()
    session.refresh(a)

@psarka

psarka commented Dec 2, 2021

Hmm, this doesn't (or at least shouldn't) typecheck :)

But I see what you did there: essentially it's the same as registering a custom serializer, but done manually.

@TheJedinator

It does type check when you create the J object (which it should), so if you tried to supply a string, J(j="foo") would fail.

This allows for type checking of the object; the A class requires a serialized version of J in order for it to be written to the database.

It is essentially the same as registering a custom serializer, but it lets you be explicit about using it.

@HenningScheufler

A hacky method with type checking that works with SQLite:

from sqlalchemy import Column
from typing import List
# from sqlalchemy.dialects.postgresql import JSONB
from sqlmodel import Field
from sqlmodel import Session
from pydantic import validator
from sqlmodel import SQLModel, JSON,create_engine

# from engine import get_sqlalchemy_engine
sqlite_file_name = "test.db"
sqlite_url = f"sqlite:///{sqlite_file_name}"

engine = create_engine(sqlite_url)


class J2(SQLModel):
    test: List[int]

class J(SQLModel):
    j: int
    nested: J2


class A(SQLModel, table=True):
    a: int = Field(primary_key=True)
    b: J = Field(sa_column=Column(JSON))

    @validator('b')
    def val_b(cls, val):
        return val.dict()

SQLModel.metadata.create_all(engine)

with Session(engine) as session:
    a = A(a=1, b=J(j=1,nested=J2(test=[100,100,100])))
    session.add(a)
    session.commit()
    session.refresh(a)

@hakanoktay

hakanoktay commented Feb 10, 2022

Hi,
I created a "JSON field" based on what is written here. I am using SQLite.

from sqlmodel import SQLModel,Relationship,Field,JSON
from typing import Optional,List, Dict
from sqlalchemy import Column
from pydantic import validator


#
class J2(SQLModel):
    id: int
    title:str

#
class Companies(SQLModel, table=True):
    id:Optional[int]=Field(default=None,primary_key=True)
    name:str
    adddresses: List['J2'] = Field(sa_column=Column(JSON))


    @validator('adddresses')
    def val_b(cls, val):
        print(val)
        return val.dict()

This gives the error:

TypeError: Type is not JSON serializable: J2

When I print it, it returns

[J2(id=1, title='address1'), J2(id=2, title='address2')]

How can I handle that? Why is this J2 added, and how can I get rid of it? I can't turn it into a dict with .dict(), and I can't serialize it... Can you give me an idea?

@HenningScheufler

Does this work?

    @validator('adddresses')
    def val_b(cls, value):
        print(value)
        return [v.dict() for v in value]

@hakanoktay
@HenningScheufler thank you for your help, it worked perfect.

@MaximilianFranz

Hey all,

Thanks for the great advice here. Creating the object using the classes and writing it to the DB works as expected and writes the data as a dict into a JSON field.

See this example:

class ComplexHeroField(SQLModel, table=False):
    some: str
    other: float
    more: Optional[List[str]]

class Hero(SQLModel, table=True):
    id: Optional[int] = Field(default=None, primary_key=True)
    complex_field: ComplexHeroField = Field(sa_column=Column(JSON))
    name: str
    secret_name: str
    age: Optional[int] = None

    @validator('complex_field')
    def val_complex(cls, val: ComplexHeroField):
        # Used in order to store pydantic models as dicts
        return val.dict()

    class Config:
        arbitrary_types_allowed = True

However, when reading the model back from the DB using select(), I would want the JSON field to be parsed into a ComplexHeroField instance using Pydantic's parse_raw or parse_obj. Because of the way it's currently done (with the validator), this happens:

statement = select(Hero)
results = session.exec(statement)
for hero in results:
    print(hero.complex_field.some)

# AttributeError: 'dict' object has no attribute 'some'

Any hint on how that could be achieved? Maybe via the custom serialiser mentioned by @psarka?

Thanks already!

@MaximilianFranz

MaximilianFranz commented Mar 16, 2022

Something like this works, but obviously doesn't scale if we have multiple nested models instead of just the ComplexHeroField:


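import json

from sqlmodel import create_engine

def custom_serializer(d):
    # .dict() keeps nested models as JSON objects rather than JSON strings
    return json.dumps(d, default=lambda v: v.dict())

def custom_deserialiser(d):
    # Every JSON column is parsed as ComplexHeroField, hence the scaling problem
    return ComplexHeroField.parse_raw(d)

engine = create_engine(url_string, echo=True, json_serializer=custom_serializer, json_deserializer=custom_deserialiser)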

complex_value = ComplexHeroField(some="value", other=5, more=["dd", "sdf"])
hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson", complex_field=complex_value)
session.add(hero_1)
session.commit()

statement = select(Hero)
results = session.exec(statement)
for hero in results:
    print(hero.complex_field.some)
    # value 

Instead, we would need more context in the deserialiser (i.e. access to the type hint of the field we're trying to deserialise), so that we can call UseType.parse_raw().

Any hint on where and how I could get that kind of access to the deserialisation process?

Thanks :)

@MaximilianFranz

Hey all,

after looking at this again, I've been able to resolve it as follows.

For our sqlalchemy models we created this PydanticJSONType factory:

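import json
from typing import Generic, TypeVar

from pydantic import parse_obj_as  # Pydantic v1 API
from pydantic.main import ModelMetaclass  # Pydantic v1 location
from sqlalchemy.types import JSON, TypeDecorator

T = TypeVar("T")

# recursive_custom_encoder is defined elsewhere in our codebase

def pydantic_column_type(pydantic_type):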
    class PydanticJSONType(TypeDecorator, Generic[T]):
        impl = JSON()

        def __init__(
            self, json_encoder=json,
        ):
            self.json_encoder = json_encoder
            super(PydanticJSONType, self).__init__()

        def bind_processor(self, dialect):
            impl_processor = self.impl.bind_processor(dialect)
            dumps = self.json_encoder.dumps
            if impl_processor:

                def process(value: T):
                    if value is not None:
                        if isinstance(pydantic_type, ModelMetaclass):
                            # This allows to assign non-InDB models and if they're
                            # compatible, they're directly parsed into the InDB
                            # representation, thus hiding the implementation in the
                            # background. However, the InDB model will still be returned
                            value_to_dump = pydantic_type.from_orm(value)
                        else:
                            value_to_dump = value
                        value = recursive_custom_encoder(value_to_dump)
                    return impl_processor(value)

            else:

                def process(value):
                    if isinstance(pydantic_type, ModelMetaclass):
                        # This allows to assign non-InDB models and if they're
                        # compatible, they're directly parsed into the InDB
                        # representation, thus hiding the implementation in the
                        # background. However, the InDB model will still be returned
                        value_to_dump = pydantic_type.from_orm(value)
                    else:
                        value_to_dump = value
                    value = dumps(recursive_custom_encoder(value_to_dump))
                    return value

            return process

        def result_processor(self, dialect, coltype) -> T:
            impl_processor = self.impl.result_processor(dialect, coltype)
            if impl_processor:

                def process(value):
                    value = impl_processor(value)
                    if value is None:
                        return None

                    data = value
                    # Explicitly use the generic directly, not type(T)
                    full_obj = parse_obj_as(pydantic_type, data)
                    return full_obj

            else:

                def process(value):
                    if value is None:
                        return None

                    # Explicitly use the generic directly, not type(T)
                    full_obj = parse_obj_as(pydantic_type, value)
                    return full_obj

            return process

        def compare_values(self, x, y):
            return x == y

    return PydanticJSONType

where recursive_custom_encoder() is essentially FastAPI's jsonable_encoder.

Using this in SQLModel as follows:

class ConnectionResistances(SQLConnectionModel, table=False):
    very_short: ResistancesInLoadDuration = ResistancesInLoadDuration()
    short: ResistancesInLoadDuration = ResistancesInLoadDuration()
    middle: ResistancesInLoadDuration = ResistancesInLoadDuration()
    long: ResistancesInLoadDuration = ResistancesInLoadDuration()
    constant: ResistancesInLoadDuration = ResistancesInLoadDuration()
    earth_quake: ResistancesInLoadDuration = ResistancesInLoadDuration()

class Connection(SQLConnectionModel, table=True):

    id: Optional[uuid.UUID] = Field(default=None, sa_column=Column(PGUUID(as_uuid=True), default=uuid.uuid4, primary_key=True))
    name: str
    comment: str
    path_to_pdf: Optional[str] = None
    resistance_values: ConnectionResistances = Field(..., sa_column=Column(pydantic_column_type(ConnectionResistances)))

Works perfectly!
That means:

  • In the DB data is stored as JSON
  • whenever the model is read from the DB, data is read into the pydantic class (including validation)
  • whenever the model is written into the DB, the data is transformed into JSON

This could be integrated into the SQLModel API based on the type hint alone (i.e. creating the sa_column automatically from the Pydantic type), potentially in get_sqlalchemy_type.

What do you think, @tiangolo?

@tchaton

tchaton commented Aug 3, 2022

@tiangolo Any updates?

@tchaton

tchaton commented Aug 3, 2022

Hey @MaximilianFranz, would you mind sharing your entire solution? I am quite interested in trying it out, but some code pieces are missing.

@MaximilianFranz
What exactly are you missing? Happy to provide more context!

@tchaton

tchaton commented Aug 3, 2022

The recursive_custom_encoder is missing. Ideally, a fully working example I can simply copy/paste and adapt to my use case ;)

@MaximilianFranz

MaximilianFranz commented Aug 3, 2022

You can use FastAPI's jsonable_encoder instead of recursive_custom_encoder:

from fastapi.encoders import jsonable_encoder

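Also, I would start with a simpler model, like:

import uuid
from typing import Optional

from sqlalchemy.dialects.postgresql import UUID as PGUUID
from sqlmodel import Column, Field, SQLModel

class NestedModel(SQLModel):
    some_value: str

class OuterModel(SQLModel, table=True):
    # as_uuid=True returns uuid.UUID values, so annotate the field accordingly
    guid: Optional[uuid.UUID] = Field(
        default=None,
        sa_column=Column(PGUUID(as_uuid=True), default=uuid.uuid4, primary_key=True),
    )
    # pydantic_column_type is the factory from my earlier comment
    nested: NestedModel = Field(..., sa_column=Column(pydantic_column_type(NestedModel)))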

That should work!

@tchaton

tchaton commented Aug 3, 2022

Thanks, @MaximilianFranz, let me try. My code is here: https://github.com/Lightning-AI/lightning-hpo/blob/master/lightning_hpo/commands/sweep.py#L36. I'm trying to store the Sweep distributions.
Do you think it would work with the recursion?
The snippet is also missing parse_obj_as and ModelMetaclass.

@tchaton

tchaton commented Aug 3, 2022

Hey @MaximilianFranz

I have made a draft PR there: https://github.com/Lightning-AI/lightning-hpo/pull/19/files. I tried but it is raising an error. Would you mind having a look?

Best,
T.C

@MaximilianFranz

MaximilianFranz commented Aug 3, 2022

Both parse_obj_as and ModelMetaclass can be imported from pydantic:

from pydantic import parse_obj_as
from pydantic.main import ModelMetaclass

As for the error, would you mind pointing me to the action that fails or post a traceback somewhere?

@MaximilianFranz

It makes sense that it doesn't work yet. You'll have to use the ModelMetaclass as is done in my snippet above for the isinstance check. Also the import for parse_obj_as is missing, so it can't work as it is :)

@tchaton

tchaton commented Aug 3, 2022

Hey @MaximilianFranz, I updated the code with your inputs, but it is still failing. I pushed the updated code.

  File "/Users/thomas/Documents/GitHub/LAI-lightning-hpo-App/lightning_hpo/components/servers/db/server.py", line 42, in insert_sweep
    session.commit()
  File "/Users/thomas/Documents/GitHub/lightning/.venv/lib/python3.8/site-packages/sqlalchemy/orm/session.py", line 1451, in commit
    self._transaction.commit(_to_root=self.future)
  File "/Users/thomas/Documents/GitHub/lightning/.venv/lib/python3.8/site-packages/sqlalchemy/orm/session.py", line 829, in commit
    self._prepare_impl()
  File "/Users/thomas/Documents/GitHub/lightning/.venv/lib/python3.8/site-packages/sqlalchemy/orm/session.py", line 808, in _prepare_impl
    self.session.flush()
  File "/Users/thomas/Documents/GitHub/lightning/.venv/lib/python3.8/site-packages/sqlalchemy/orm/session.py", line 3383, in flush
    self._flush(objects)
  File "/Users/thomas/Documents/GitHub/lightning/.venv/lib/python3.8/site-packages/sqlalchemy/orm/session.py", line 3523, in _flush
    transaction.rollback(_capture_exception=True)
  File "/Users/thomas/Documents/GitHub/lightning/.venv/lib/python3.8/site-packages/sqlalchemy/util/langhelpers.py", line 70, in __exit__
    compat.raise_(
  File "/Users/thomas/Documents/GitHub/lightning/.venv/lib/python3.8/site-packages/sqlalchemy/util/compat.py", line 208, in raise_
    raise exception
  File "/Users/thomas/Documents/GitHub/lightning/.venv/lib/python3.8/site-packages/sqlalchemy/orm/session.py", line 3483, in _flush
    flush_context.execute()
  File "/Users/thomas/Documents/GitHub/lightning/.venv/lib/python3.8/site-packages/sqlalchemy/orm/unitofwork.py", line 456, in execute
    rec.execute(self)
  File "/Users/thomas/Documents/GitHub/lightning/.venv/lib/python3.8/site-packages/sqlalchemy/orm/unitofwork.py", line 630, in execute
    util.preloaded.orm_persistence.save_obj(
  File "/Users/thomas/Documents/GitHub/lightning/.venv/lib/python3.8/site-packages/sqlalchemy/orm/persistence.py", line 245, in save_obj
    _emit_insert_statements(
  File "/Users/thomas/Documents/GitHub/lightning/.venv/lib/python3.8/site-packages/sqlalchemy/orm/persistence.py", line 1238, in _emit_insert_statements
    result = connection._execute_20(
  File "/Users/thomas/Documents/GitHub/lightning/.venv/lib/python3.8/site-packages/sqlalchemy/engine/base.py", line 1631, in _execute_20
    return meth(self, args_10style, kwargs_10style, execution_options)
  File "/Users/thomas/Documents/GitHub/lightning/.venv/lib/python3.8/site-packages/sqlalchemy/sql/elements.py", line 332, in _execute_on_connection
    return connection._execute_clauseelement(
  File "/Users/thomas/Documents/GitHub/lightning/.venv/lib/python3.8/site-packages/sqlalchemy/engine/base.py", line 1498, in _execute_clauseelement
    ret = self._execute_context(
  File "/Users/thomas/Documents/GitHub/lightning/.venv/lib/python3.8/site-packages/sqlalchemy/engine/base.py", line 1862, in _execute_context
    self._handle_dbapi_exception(
  File "/Users/thomas/Documents/GitHub/lightning/.venv/lib/python3.8/site-packages/sqlalchemy/engine/base.py", line 2043, in _handle_dbapi_exception
    util.raise_(
  File "/Users/thomas/Documents/GitHub/lightning/.venv/lib/python3.8/site-packages/sqlalchemy/util/compat.py", line 208, in raise_
    raise exception
  File "/Users/thomas/Documents/GitHub/lightning/.venv/lib/python3.8/site-packages/sqlalchemy/engine/base.py", line 1819, in _execute_context
    self.dialect.do_execute(
  File "/Users/thomas/Documents/GitHub/lightning/.venv/lib/python3.8/site-packages/sqlalchemy/engine/default.py", line 732, in do_execute
    cursor.execute(statement, parameters)
sqlalchemy.exc.InterfaceError: (sqlite3.InterfaceError) Error binding parameter 5 - probably unsupported type.
[SQL: INSERT INTO sweepconfig (distributions, sweep_id, script_path, n_trials, simultaneous_trials, requirements, script_args, framework, cloud_compute, num_nodes, logger, direction) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)]
[parameters: ('{"name": "model.lr", "distribution": "uniform", "params": {"params": {"low": "0.001", "high": "0.1"}}}', 'thomas-5e0dd935', 'train.py', 1, 1, [], [], 'pytorch_lightning', 'cpu', 1, 'wandb', 'maximize')]
(Background on this error at: https://sqlalche.me/e/14/rvf5)

@MaximilianFranz
To close this out: the problem turned out to be an attribute of type List[str] on the SQLModel, which is not natively supported. Using the pydantic_column_type above with List[str] does work, and encodes the list as a JSON string in order to store it in the database.
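The traceback above is the sqlite3 driver refusing to bind a Python list, which is what happens when a List[str] attribute reaches the database without a JSON-aware column type. A stdlib-only sketch of the failure and of the fix (encoding to JSON text first); the table and column names mirror the traceback, the values are hypothetical:

```python
import json
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE sweepconfig (id INTEGER PRIMARY KEY, requirements TEXT)")

requirements = ["torch", "lightning"]

# sqlite3 binds only None, int, float, str and bytes by default; a list is rejected
# (InterfaceError on older Pythons, ProgrammingError on newer ones)
try:
    conn.execute("INSERT INTO sweepconfig (requirements) VALUES (?)", (requirements,))
    bound_raw_list = True
except (sqlite3.InterfaceError, sqlite3.ProgrammingError):
    bound_raw_list = False

# Encoding the list as JSON text first, as a JSON column type does, works
conn.execute("INSERT INTO sweepconfig (requirements) VALUES (?)", (json.dumps(requirements),))
stored = conn.execute("SELECT requirements FROM sweepconfig").fetchone()[0]

print(bound_raw_list, json.loads(stored))  # False ['torch', 'lightning']
```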

@felipemonroy

felipemonroy commented Aug 12, 2024

I was able to solve this issue using the json-fix library. You only need to add the following method to your nested models.

    def __json__(self):
        return self.model_dump()

I hope this can be done without an external library.
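If you'd rather avoid the extra dependency, the same `__json__` convention can be honoured with a small `json.JSONEncoder` subclass, which could then be passed to the engine as `create_engine(..., json_serializer=dumps)`. A stdlib-only sketch; `DunderJSONEncoder`, `dumps` and `Tag` are hypothetical names, and json-fix's own mechanism may differ:

```python
import json
from typing import Any


class DunderJSONEncoder(json.JSONEncoder):
    """Defer to an object's __json__() method when it has one."""

    def default(self, o: Any) -> Any:
        to_json = getattr(o, "__json__", None)
        if callable(to_json):
            return to_json()
        return super().default(o)  # raises TypeError for unsupported types


def dumps(value: Any) -> str:
    return json.dumps(value, cls=DunderJSONEncoder)


class Tag:
    """Hypothetical nested object; a Pydantic model would return self.model_dump()."""

    def __init__(self, name: str) -> None:
        self.name = name

    def __json__(self) -> dict:
        return {"name": self.name}


encoded = dumps({"tags": [Tag("a"), Tag("b")]})
print(encoded)  # {"tags": [{"name": "a"}, {"name": "b"}]}
```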

@alexdashly

alexdashly commented Sep 18, 2024

This is my approach, using a conventional SQLAlchemy pattern. Others have done it the same way with SQLAlchemy and Pydantic, for instance in this discussion: sqlalchemy/sqlalchemy#11050

from typing import Any, Self

from pydantic import BaseModel as _BaseModel
from sqlalchemy import JSON, types, Column
from sqlalchemy.ext.mutable import Mutable
from sqlmodel import SQLModel, Field


class JsonPydanticField(types.TypeDecorator):
    impl = JSON

    def __init__(self, pydantic_model):
        super().__init__()
        self.pydantic_model = pydantic_model

    def load_dialect_impl(self, dialect):
        return dialect.type_descriptor(JSON())

    def process_bind_param(self, value: _BaseModel, _):
        return value.model_dump() if value is not None else None

    def process_result_value(self, value, _):
        return self.pydantic_model.model_validate(value) if value is not None else None


class MutableSABaseModel(_BaseModel, Mutable):

    def __setattr__(self, name: str, value: Any) -> None:
        """Allows the SQLAlchemy Session to track mutations"""
        self.changed()
        return super().__setattr__(name, value)

    @classmethod
    def coerce(cls, key: str, value: Any) -> Self | None:
        """Convert JSON to pydantic model object allowing for mutable behavior"""
        if isinstance(value, cls) or value is None:
            return value

        if isinstance(value, str):
            return cls.model_validate_json(value)

        if isinstance(value, dict):
            return cls(**value)

        return super().coerce(key, value)

    @classmethod
    def to_sa_type(cls):
        return cls.as_mutable(JsonPydanticField(cls))


class Nested(MutableSABaseModel):
    a: str
    b: str | None = None


NestedSAType = Nested.to_sa_type()


class DBModel(SQLModel, table=True):

    id: str = Field(primary_key=True)
    nested: Nested = Field(sa_column=Column(NestedSAType))

@TechLipefi

Hey there,
@tiangolo @estebanx64 could you please clarify whether this will be on the roadmap at some point, or whether we should come up with a workaround like the one @alexdashly proposed?
We also adopted SQLAlchemy's Mutable type, and it works well (typing is a bit messy, but it's fine).

@dadodimauro

dadodimauro commented Oct 30, 2024

I had a similar problem with a list of nested Pydantic models:

from sqlmodel import JSON, Column, Field, SQLModel, create_engine

class MyNestedModel(SQLModel):
    a: str
    b: str | None

class MyModel(SQLModel):
    c: list[MyNestedModel] | None = Field(
        default=None, sa_column=Column(JSON)
    )

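I solved it by passing the following serializer to create_engine():

import json

from pydantic import BaseModel

def serialize_pydantic_model(model: BaseModel | list[BaseModel] | None) -> str:
    if isinstance(model, BaseModel):
        return model.model_dump_json()
    if isinstance(model, list):
        # model_dump(mode="json") keeps each item a JSON object;
        # model_dump_json() here would store a list of JSON strings
        return json.dumps([m.model_dump(mode="json") for m in model])
    # Anything else (dicts, numbers, None) goes through the standard encoder
    return json.dumps(model)

I know it is not an elegant solution, but it was enough to make it work.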

@dataengineeringatfunderzgroup

dataengineeringatfunderzgroup commented Oct 31, 2024

Do you have the complete code, @dadodimauro? I'm trying to use it.

@dataengineeringatfunderzgroup

I solved it!

I created a list field and encoded it with jsonable_encoder.

Here is my code snippet.

SOLUTION

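import uuid
from datetime import datetime
from decimal import Decimal
from typing import Optional

from fastapi.encoders import jsonable_encoder
from sqlmodel import Field, Session, SQLModel, create_engine, select, JSON, Column

# router (a FastAPI APIRouter) and engine are defined elsewhere in the app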

class SKU(SQLModel):
    sku_item_number: str
    quantity: int
    sku_price: Decimal

class SalesBase(SQLModel):
    business_id: str = Field(index=True, description='Unique identifier for the business')
    date: datetime = Field(description='Date of the sales transaction | Format: YYYY-MM-DD')
    sale_amount: Decimal = Field(description='Total amount of sales in a given transaction')
    sale_count: int = Field(description='The number of sale transactions')
    reversal_amount: Optional[Decimal] = Field(description='Total amount of reversals (refunds) within the transaction')
    reversal_count: Optional[int] = Field(description='The number of reversals within the transaction')
    currency: str = Field(description='The currency of the transaction (e.g., USD, EUR)', default='USD')
    skus: list[SKU] | None = Field(
        default=None,
        sa_column=Column(JSON)
    )

    # Needed for Column(JSON)
    class Config:
        arbitrary_types_allowed = True

class Tb_Sales(SalesBase, table=True):
    id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)


@router.post("/add")
async def add(item: SalesBase):
    tb_sales = Tb_Sales()

    with Session(engine) as session:

        # tb_sales
        tb_sales.business_id = item.business_id
        tb_sales.date = item.date
        tb_sales.sale_amount = item.sale_amount
        tb_sales.sale_count = item.sale_count
        tb_sales.reversal_amount = item.reversal_amount
        tb_sales.reversal_count = item.reversal_count
        tb_sales.currency = str(item.currency).upper()
        tb_sales.skus = jsonable_encoder(item.skus)
        session.add(tb_sales)

        # commit
        session.commit()
        session.refresh(tb_sales)

    out = {
        "message": "item created!",
        "id": tb_sales.id,
        "details": tb_sales
    }

    return out
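jsonable_encoder makes this work because it turns the SKU models into plain JSON-compatible data before SQLAlchemy's JSON serializer runs. The same round trip can be sketched with Pydantic v2 alone using TypeAdapter. TypeAdapter is swapped in here for illustration and is not what the snippet above calls. Pydantic v2's JSON mode serializes Decimal as a string, which matches the "sku_price": "10" values in the output below.

```python
from decimal import Decimal

from pydantic import BaseModel, TypeAdapter


class SKU(BaseModel):
    sku_item_number: str
    quantity: int
    sku_price: Decimal


# Pydantic-only analogue of jsonable_encoder(item.skus): models -> JSON-compatible data
skus_adapter = TypeAdapter(list[SKU])

skus = [SKU(sku_item_number="111", quantity=10, sku_price=Decimal("10"))]
as_json_data = skus_adapter.dump_python(skus, mode="json")

# ...and the reverse direction, e.g. when reading the JSON column back into models
restored = skus_adapter.validate_python(as_json_data)
```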

Output of the data saved in Postgres:

[
  {
    "sale_count": 10,
    "business_id": "999",
    "reversal_count": 10,
    "skus": null,
    "reversal_amount": "10",
    "date": "2024-10-31T01:42:58.469000",
    "sale_amount": "10",
    "currency": "USD",
    "id": "f10a3faf-ce80-46ea-a07c-ef7787288ad9"
  },
  {
    "sale_count": 10,
    "business_id": "999",
    "reversal_count": 10,
    "skus": [
      {
        "sku_item_number": "111",
        "quantity": 10,
        "sku_price": "10"
      }
    ],
    "reversal_amount": "10",
    "date": "2024-10-31T01:42:58.469000",
    "sale_amount": "10",
    "currency": "USD",
    "id": "2b4a04fd-8703-4e1b-a6cf-1c6b2a91d3cc"
  },
  {
    "sale_count": 20,
    "business_id": "string",
    "reversal_count": 20,
    "skus": [
      {
        "sku_item_number": "999",
        "quantity": 20,
        "sku_price": "20"
      }
    ],
    "reversal_amount": "20",
    "date": "2024-10-31T02:04:09.032000",
    "sale_amount": "20",
    "currency": "USD",
    "id": "0b76d570-83aa-48e2-8267-d39549da5fc5"
  },
  {
    "sale_count": 0,
    "business_id": "string",
    "reversal_count": 0,
    "skus": [
      {
        "sku_item_number": "string",
        "quantity": 0,
        "sku_price": "0"
      }
    ],
    "reversal_amount": "0",
    "date": "2024-10-31T02:06:38.215000",
    "sale_amount": "0",
    "currency": "USD",
    "id": "f9ae39a1-3bed-4af8-b1ed-8501d41ffa32"
  },
  {
    "sale_count": 0,
    "business_id": "string",
    "reversal_count": 0,
    "skus": [
      {
        "sku_item_number": "1",
        "quantity": 10,
        "sku_price": "10"
      },
      {
        "sku_item_number": "2",
        "quantity": 20,
        "sku_price": "20"
      },
      {
        "sku_item_number": "3",
        "quantity": 30,
        "sku_price": "30"
      }
    ],
    "reversal_amount": "0",
    "date": "2024-10-31T02:06:38.215000",
    "sale_amount": "0",
    "currency": "USD",
    "id": "f1c59936-0a8a-4134-a642-4901e75d563c"
  }
]

@gray-adeyi

Hi! I ran into a similar issue, so I used @alexdashly's solution. The first problem I hit with it was that the data was stored in the database as a JSON string, even when the dumped model was an array. That ruled out some kinds of queries, such as fetching items from the database based on a value inside a JSON field, so I modified it to:

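from enum import Enum
from typing import Any, ClassVar
from uuid import UUID, uuid4

from fastapi.encoders import jsonable_encoder
from pydantic import AwareDatetime
from pydantic import BaseModel as _BaseModel
from sqlalchemy import TIMESTAMP, types
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.mutable import Mutable
from sqlmodel import Field, Relationship, SQLModel
from typing_extensions import Self, override

# BaseModelManager, OrganizationModelManager, aware_datetime_now and User are
# project-specific helpers not shown here (see the linked repository).


class JSONBPydanticField(types.TypeDecorator):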
    """This is a custom SQLAlchemy field that allows easy serialization between database JSONB types and Pydantic models"""

    impl = JSONB

    def __init__(
        self,
        pydantic_model_class: type["MutableSABaseModel"],
        many: bool = False,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.pydantic_model_class = pydantic_model_class
        self.many = many

    def load_dialect_impl(self, dialect):
        return dialect.type_descriptor(JSONB())

    def process_bind_param(self, value: _BaseModel | list[_BaseModel], dialect):
        """Convert Pydantic models to JSON-compatible data before storing in the database"""
        # `is not None` so an empty list is stored as [] instead of NULL
        return jsonable_encoder(value) if value is not None else None

    def process_result_value(self, value, dialect):
        """Convert the stored JSON data back to Pydantic models after retrieving it from the database"""
        if self.many:
            return (
                [self.pydantic_model_class.model_validate(v) for v in value]
                if value is not None
                else None
            )
        return (
            self.pydantic_model_class.model_validate(value)
            if value is not None
            else None
        )

class MutableSAList(list, Mutable):
    """This is a hack intended to let SQLAlchemy detect changes in a JSON field that is a native Python list.
    Allows the SQLAlchemy Session to track mutable behavior"""

    @override
    def append(self, __object):
        self.changed()
        super().append(__object)

    @override
    def remove(self, __value):
        self.changed()
        super().remove(__value)

    @override
    def pop(self, __index=-1):
        self.changed()
        return super().pop(__index)

    @override
    def reverse(self):
        self.changed()
        super().reverse()

    @override
    def __setattr__(self, name: str, value: Any) -> None:
        self.changed()
        super().__setattr__(name, value)

    @override
    def __setitem__(self, key, value):
        self.changed()
        super().__setitem__(key, value)

    @override
    def __delitem__(self, key):
        self.changed()
        super().__delitem__(key)

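    def __iadd__(self, other):
        self.changed()
        # list.__iadd__ returns self; without returning it, `obj.members += [...]` would rebind to None
        return super().__iadd__(other)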


class MutableSABaseModel(_BaseModel, Mutable):
    """This is a hack intended to let SQLAlchemy detect changes in a JSON field that is a Pydantic model"""

    def __setattr__(self, name: str, value: Any) -> None:
        """Allows the SQLAlchemy Session to track mutable behavior"""
        self.changed()
        return super().__setattr__(name, value)

    @classmethod
    def coerce(cls, key: str, value: Any) -> Self | None:
        """Convert JSON to pydantic model object allowing for mutable behavior"""
        if isinstance(value, cls) or value is None:
            return value

        if isinstance(value, str):
            return cls.model_validate_json(value)

        if isinstance(value, dict):
            return cls.model_validate(value)

        if isinstance(value, list):
            return MutableSAList([cls.model_validate(v) for v in value])

        return super().coerce(key, value)

    @classmethod
    def to_sa_type(cls, many=False):
        return cls.as_mutable(JSONBPydanticField(pydantic_model_class=cls, many=many))


class BaseDBModel(SQLModel):
    id: UUID = Field(default_factory=uuid4, primary_key=True)
    created_at: AwareDatetime = Field(
        default_factory=aware_datetime_now, sa_type=TIMESTAMP(timezone=True)
    )
    last_updated_at: AwareDatetime | None = Field(sa_type=TIMESTAMP(timezone=True))

    objects: ClassVar[BaseModelManager] = BaseModelManager()

class OrganizationMemberPermission(str, Enum):
    MANAGE_EVENTS = "EVENT:WRITE"
    INVITE_MEMBERS = "MEMBERS:INVITE"
    APPROVE_REQUESTS = "MEMBERS:APPROVE_REQUEST"


class OrganizationMember(MutableSABaseModel):
    id: UUID
    role: str
    permissions: list[OrganizationMemberPermission] = Field(
        description="A list of administrative features an organization member can perform in the organization"
    )


OrganizationMembersSAType = OrganizationMember.to_sa_type(many=True)


class Organization(BaseDBModel, table=True):
    __tablename__ = "organizations"

    name: str = Field(max_length=128, unique=True)
    is_verified: bool = Field(
        False,
        description="used to flag organizations that have been verified by eventtrakka",
    )
    logo_url: str | None = Field(None)
    about: str | None
    owner_id: UUID = Field(foreign_key="users.id")
    owner: "User" = Relationship()
    members: list[OrganizationMember] = Field(
        default_factory=list,
        sa_type=OrganizationMembersSAType,
    )

    objects: ClassVar[OrganizationModelManager["Organization"]] = (
        OrganizationModelManager()
    )

MutableSAList doesn't look great, but it's what I fell back to after a lot of back-and-forth with ChatGPT. All Pydantic types used in JSON fields inherit from MutableSABaseModel, and with that I was able to run the following query, which filters organizations by whether a user is a member:

async def get_organizations_as_member(
    self,
    member: "User",
    session: AsyncSession | None = None,
) -> list[T]:
    async for s in get_db_session():
        session = s or session
        query = (
            select(self.model_class)
            .select_from(self.model_class)
            .join(
                func.jsonb_array_elements(self.model_class.members).alias(
                    "members_jsonb"
                ),
                text("true"),  # LATERAL join
            )
            .where(
                func.jsonb_extract_path_text(column("members_jsonb"), "id")
                == str(member.id)
            )
        )
        return await paginate(session, query)
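This query depends on PostgreSQL functions. One way to check its shape without a database is to compile it against the PostgreSQL dialect. The sketch below uses a stand-in Core Table with a JSONB members column (not the author's model) and a placeholder member id.

```python
from sqlalchemy import Column, MetaData, Table, column, func, select, text
from sqlalchemy.dialects import postgresql
from sqlalchemy.dialects.postgresql import JSONB, UUID

metadata = MetaData()
organizations = Table(
    "organizations",
    metadata,
    Column("id", UUID, primary_key=True),
    Column("members", JSONB),
)

# One row per array element; PostgreSQL lets a function in FROM reference earlier FROM items
members = func.jsonb_array_elements(organizations.c.members).alias("members_jsonb")

query = (
    select(organizations.c.id)
    .select_from(organizations)
    .join(members, text("true"))
    .where(func.jsonb_extract_path_text(column("members_jsonb"), "id") == "placeholder-user-id")
)

# Compiling needs only the dialect, not a live connection
sql = str(query.compile(dialect=postgresql.dialect()))
```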

My current limitation is the way Alembic autogenerates migrations: it drops the pydantic_model_class and many arguments, so I have to edit the migration files by hand.

# from this
sa.Column(
    "members",
    JSONBPydanticField(
        astext_type=Text(),
    ),
    nullable=False,
),

# to this
sa.Column(
    "members",
    JSONBPydanticField(
        pydantic_model_class=OrganizationMember,
        many=True,
        astext_type=sa.Text(),
    ),
    nullable=False,
),

The entire code is at https://github.com/OSCA-Ado-Ekiti/EventTrakka-Backend .

@tcztzy

tcztzy commented Nov 19, 2024


@gray-adeyi, you can avoid redefining MutableSAList by using sqlalchemy.ext.mutable.MutableList.

@IWillChangeTheNameLater

This might not be the perfect solution, but it worked for me. It's easy to implement, passes mypy checks in strict mode, and lets you conveniently configure serialization and model validation. It probably won't work with types like list[Custom], but I think that's easy to fix.

from sqlmodel import SQLModel, Field, JSON
from functools import partial
from pydantic import BaseModel, BeforeValidator, AfterValidator
from typing import Any, Annotated


def convert_dict_to_model(
        data: dict[str, Any]|BaseModel, *, model: type[BaseModel]
) -> BaseModel:
    if isinstance(data, BaseModel):
        return data
    return model(**data)


def convert_dict_to_model_validator_factory(
        model: type[BaseModel]
) -> BeforeValidator:
    return BeforeValidator(partial(convert_dict_to_model, model=model))


def convert_model_to_dict(model: BaseModel|dict) -> dict[str, Any]:
    if isinstance(model, dict):
        return model
    return model.model_dump(
        mode='json',
        exclude_none=True,
        exclude_unset=True,
        exclude_defaults=True
    )


DictToModelValidator = convert_dict_to_model_validator_factory
ModelToDictValidator = AfterValidator(convert_model_to_dict)


class Child(SQLModel):
    one: int
    two: float


class Custom(SQLModel):
    boolean: bool | None = None
    parent: Child


CustomType = Annotated[
                Custom,
                DictToModelValidator(Custom),
                ModelToDictValidator
            ]


class Table(SQLModel, table=True):
    id: int | None = Field(default=None, primary_key=True)

    profile: CustomType = Field(sa_type=JSON)
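The validator pair in CustomType can be tested on its own. This reduced sketch uses plain Pydantic models, and Holder is a hypothetical stand-in for the table model. It shows three things:

- The BeforeValidator accepts a dict.
- The model still validates that dict, rejecting bad input.
- The AfterValidator leaves a plain dict on the attribute, which is a form a JSON column can store directly.

```python
from typing import Annotated, Any

from pydantic import AfterValidator, BaseModel, BeforeValidator, ValidationError


class Child(BaseModel):
    one: int
    two: float


def to_model(data: Any) -> Any:
    # Before: accept a dict (e.g. loaded from a JSON column) and build the model
    return Child.model_validate(data) if isinstance(data, dict) else data


def to_dict(model: Child) -> dict[str, Any]:
    # After: hand back plain JSON-compatible data
    return model.model_dump(mode="json")


ChildAsDict = Annotated[Child, BeforeValidator(to_model), AfterValidator(to_dict)]


class Holder(BaseModel):
    child: ChildAsDict


holder = Holder(child={"one": 1, "two": "2.5"})

try:
    Holder(child={"one": "not a number", "two": 1})
    rejected = False
except ValidationError:
    rejected = True
```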

@Anudorannador

@MaximilianFranz As time goes by, some interfaces/functions have changed:

  • parse_obj_as is replaced by pydantic_type.validate(xxx) (in Pydantic v2, pydantic_type.model_validate(xxx) is the non-deprecated spelling)
  • ModelMetaclass moved to a private module: from pydantic._internal._model_construction import ModelMetaclass
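For reference, the sketch below shows the public Pydantic v2 equivalents of those calls. NestedClass is a placeholder model. For a plain class, an issubclass(..., BaseModel) check is one way to avoid importing the private ModelMetaclass.

```python
from pydantic import BaseModel, TypeAdapter


class NestedClass(BaseModel):
    x: int


data = [{"x": 1}, {"x": "2"}]

# Pydantic v1: parse_obj_as(list[NestedClass], data)
items = TypeAdapter(list[NestedClass]).validate_python(data)

# Pydantic v1: NestedClass.parse_obj(...) / NestedClass.validate(...)
one = NestedClass.model_validate({"x": 3})

# "Is this a Pydantic model class?" without the private ModelMetaclass import (plain classes only)
is_model_class = isinstance(NestedClass, type) and issubclass(NestedClass, BaseModel)
```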

Here is a new implementation based on your idea. It simply uses JSONB (the most common case with PostgreSQL, I think) and uses jsonable_encoder in place of the recursive_custom_encoder function:

import json
from typing import Generic, TypeVar
from sqlalchemy import TypeDecorator
from sqlalchemy.dialects.postgresql import JSONB

from fastapi.encoders import jsonable_encoder
from pydantic._internal._model_construction import ModelMetaclass
T = TypeVar('T')


def pydantic_column_type(pydantic_type):
    class PydanticJSONType(TypeDecorator, Generic[T]):
        impl = JSONB()

        def __init__(
            self, json_encoder=json,
        ):
            self.json_encoder = json_encoder
            super(PydanticJSONType, self).__init__()

        def bind_processor(self, dialect):
            impl_processor = self.impl.bind_processor(dialect)
            dumps = self.json_encoder.dumps
            if impl_processor:

                def process(value: T):
                    if value is not None:
                        if isinstance(pydantic_type, ModelMetaclass):
                            # from_orm() is deprecated in Pydantic v2 and requires from_attributes=True
                            value_to_dump = pydantic_type.model_validate(value)
                        else:
                            value_to_dump = value
                        value = jsonable_encoder(value_to_dump)
                    return impl_processor(value)

            else:

                def process(value):
                    if value is None:
                        return None
                    if isinstance(pydantic_type, ModelMetaclass):
                        value_to_dump = pydantic_type.model_validate(value)
                    else:
                        value_to_dump = value
                    value = dumps(jsonable_encoder(value_to_dump))
                    return value

            return process

        def result_processor(self, dialect, coltype) -> T:
            impl_processor = self.impl.result_processor(dialect, coltype)
            if impl_processor:

                def process(value):
                    value = impl_processor(value)
                    if value is None:
                        return None

                    data = value
                    # Explicitly use the generic directly, not type(T)
                    full_obj = pydantic_type.model_validate(data)
                    return full_obj

            else:

                def process(value):
                    if value is None:
                        return None

                    # Explicitly use the generic directly, not type(T)
                    full_obj = pydantic_type.model_validate(value)
                    return full_obj

            return process

        def compare_values(self, x, y):
            return x == y

    return PydanticJSONType

usage:

class Outer(SQLModel, table=True):
    id: int | None = Field(default=None, primary_key=True)
    nested_obj: NestedClass = Field(sa_column=Column(pydantic_column_type(NestedClass)))

In my opinion, the more important consideration is security. We could simply store JSON data (a dict or list) from the client like this:

class Outer(SQLModel, table=True):
    id: int | None = Field(default=None, primary_key=True)
    nested_obj: Dict | List | None = Field(default=None, sa_column=Column(JSONB))

But the server doesn't know how large a dict the client will send: maybe a few key-value pairs, maybe thousands. Using Pydantic filters out the parameters the server doesn't care about, and it also provides type checking.
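Here is a small sketch of the filtering point, using a hypothetical Preferences model. With its default configuration, Pydantic v2 ignores unknown keys and coerces known ones, so only the declared fields would reach the JSONB column. Setting extra="forbid" rejects such a payload instead. Neither option limits the size of the request body itself.

```python
from pydantic import BaseModel, ConfigDict, ValidationError


class Preferences(BaseModel):
    theme: str = "light"
    page_size: int = 20


payload = {"theme": "dark", "page_size": "50", "unexpected": "x" * 10_000}

# Default config (extra="ignore"): unknown keys are dropped, known ones are type-checked
prefs = Preferences.model_validate(payload)
stored = prefs.model_dump(mode="json")


class StrictPreferences(Preferences):
    # extra="forbid" rejects unknown keys instead of dropping them
    model_config = ConfigDict(extra="forbid")


try:
    StrictPreferences.model_validate(payload)
    rejected = False
except ValidationError:
    rejected = True
```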

Also, if you don't need the convenience of attribute access (nested.key) and are fine with dict access (v = nested['key']), you can use @HenningScheufler's solution.

@drew2a

drew2a commented Jan 10, 2025

Hi guys! I experienced the same problem as some of the commenters (unable to deserialize a JSON column into an object). Here’s my solution:

from sqlalchemy import Column, JSON
from sqlalchemy.orm import reconstructor
from sqlmodel import Field, SQLModel


class Item(SQLModel):
    foo: str
    bar: str


class Model(SQLModel, table=True):
    id: int | None = Field(default=None, primary_key=True)
    items: list[Item] = Field(sa_column=Column(JSON))

    @reconstructor
    def init_on_load(self):
        if self.items:
            self.items = [Item(**item) for item in self.items]

@iloveitaly

I mixed a couple of the solutions here, and this is working decently well for me.

@TechLipefi

TechLipefi commented Feb 6, 2025

Let's upvote this issue, as this is must-have functionality.
My rationale is as follows:
When you are developing a database, you try to keep the table definitions as loose as possible, because you are not sure which data structures will be needed. So you come up with a denormalized structure built from the nested models you have at the moment. In the refinement step you arrive at a more normalized design, but you may still keep a general "extra" field for further development.

It's a pity that we need to come up with a custom solution for this.

@iloveitaly

True, although it doesn't need to be in SQLModel. It makes sense to limit a project's surface area.

This mixin has been working perfectly for me.

@zhouwubai

zhouwubai commented Feb 21, 2025

Following @MaximilianFranz and @Anudorannador, I made it work for lists of nested objects. It works perfectly with JSON, which I use because I need more readable content in the DB:

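import json
from typing import Generic, TypeVar

from fastapi.encoders import jsonable_encoder
from pydantic._internal._model_construction import ModelMetaclass
from sqlalchemy import JSON, TypeDecorator

T = TypeVar('T')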

# https://github.com/fastapi/sqlmodel/issues/63
def pydantic_column_type(pydantic_type):
    class PydanticJSONType(TypeDecorator, Generic[T]):
        impl = JSON()

        def __init__(
            self, json_encoder=json,
        ):
            self.json_encoder = json_encoder
            super(PydanticJSONType, self).__init__()

        def bind_processor(self, dialect):
            impl_processor = self.impl.bind_processor(dialect)
            dumps = self.json_encoder.dumps
            if impl_processor:
                def process(value: T):
                    if value is not None:
                        if isinstance(value, list) and isinstance(pydantic_type, ModelMetaclass):
                            value_to_dump = [pydantic_type.model_validate(item) for item in value]
                        elif isinstance(pydantic_type, ModelMetaclass):
                            value_to_dump = pydantic_type.model_validate(value)
                        else:
                            value_to_dump = value
                        value = jsonable_encoder(value_to_dump)
                    return impl_processor(value)
            else:
                def process(value):
                    if value is None:
                        return None
                    if isinstance(value, list) and isinstance(pydantic_type, ModelMetaclass):
                        value_to_dump = [pydantic_type.model_validate(item) for item in value]
                    elif isinstance(pydantic_type, ModelMetaclass):
                        value_to_dump = pydantic_type.model_validate(value)
                    else:
                        value_to_dump = value
                    value = dumps(jsonable_encoder(value_to_dump))
                    return value
            return process

        def result_processor(self, dialect, coltype) -> T:
            impl_processor = self.impl.result_processor(dialect, coltype)
            if impl_processor:
                def process(value):
                    value = impl_processor(value)
                    if value is None:
                        return None
                    data = value
                    # Explicitly use the generic directly, not type(T)
                    if isinstance(data, list) and isinstance(pydantic_type, ModelMetaclass):
                        full_obj = [pydantic_type.model_validate(item) for item in data]
                    elif isinstance(pydantic_type, ModelMetaclass):
                        full_obj = pydantic_type.model_validate(data)
                    else:
                        full_obj = data
                    return full_obj
            else:
                def process(value):
                    if value is None:
                        return None
                    # Explicitly use the generic directly, not type(T)
                    if isinstance(value, list) and isinstance(pydantic_type, ModelMetaclass):
                        full_obj = [pydantic_type.model_validate(item) for item in value]
                    elif isinstance(pydantic_type, ModelMetaclass):
                        full_obj = pydantic_type.model_validate(value)
                    else:
                        full_obj = value
                    return full_obj
            return process

        def compare_values(self, x, y):
            return x == y

    return PydanticJSONType

Example

from typing import List, Optional

from sqlalchemy import Column
from sqlmodel import Field, SQLModel


class J2(SQLModel):
    id: int
    title: str


class Companies(SQLModel, table=True):
    id: Optional[int] = Field(default=None, primary_key=True)
    name: str
    addresses: List[J2] = Field(sa_column=Column(pydantic_column_type(J2)))

There is no need to set up an engine-level serializer and deserializer, since bind_processor defines the serialization behavior and result_processor the deserialization.

@Seluj78

Seluj78 commented Mar 3, 2025

I had made this:

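from typing import Any, get_args, get_origin

from pydantic import TypeAdapter
from sqlalchemy import TypeDecorator
from sqlalchemy.dialects.postgresql import JSONB


class PydanticJSONB(TypeDecorator):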
    impl = JSONB

    def __init__(self, model_type: Any, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model_type = model_type
        self._type = get_origin(model_type)
        if self._type is list:
            self._item_type = get_args(model_type)[0]
        elif self._type is dict:
            self._item_type = get_args(model_type)[1]
        else:
            self._item_type = model_type
        self._adapter = TypeAdapter(self.model_type)

    def process_bind_param(self, value: Any, dialect: Any) -> Any:
        if value is None:
            return None

        if self._type is list:
            if not isinstance(value, list):
                raise TypeError(f"Expected list of {self._item_type}")
            # mode="json" so datetimes, UUIDs, etc. become JSON-compatible values
            return [item.model_dump(mode="json") if isinstance(item, self._item_type) else item for item in value]
        elif self._type is dict:
            if not isinstance(value, dict):
                raise TypeError(f"Expected dict of {self._item_type}")
            return {k: item.model_dump(mode="json") if isinstance(item, self._item_type) else item for k, item in value.items()}
        else:
            if isinstance(value, self.model_type):
                return value.model_dump(mode="json")
            return value

    def process_result_value(self, value: Any, dialect: Any) -> Any:
        if value is not None:
            return self._adapter.validate_python(value)
        if self._type is list:
            return []
        elif self._type is dict:
            return {}
        else:
            return None

Which was then used like this:

class Companies(_BaseModel, table=True):  # type: ignore[call-arg]
    id: str = sm.Field(  # type: ignore[call-overload]
        sa_column=sa.Column("id", sa.String, unique=True, nullable=False, primary_key=True),
        default_factory=generate_random_string,
    )

    persona_settings: Dict[PersonaID, CompanyPersonaSettings] = sm.Field(
        default_factory=dict, sa_column=sa.Column(PydanticJSONB(Dict[PersonaID, CompanyPersonaSettings]))
    )

Where PersonaID = str and CompanyPersonaSettings is a Pydantic BaseModel.

This works for assignment but not for mutations (i.e. company.persona_settings[PERSONA_ID].xxx = yyy). I will try @zhouwubai's solution.

I do think that it should be a part of the base SQLModel :)

@Seluj78

Seluj78 commented Mar 3, 2025

Update: @zhouwubai's code doesn't work for types like Dict[str, ARandomSubModel].

I'm trying to make it work and will post it here if I manage to do it

cc @tiangolo: do you know whether this will ever be implemented in SQLModel as a "real" feature? :D
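Seluj78's PydanticJSONB above already validates results with a TypeAdapter built from the full annotation. Routing the bind side through the same adapter is one way to handle Dict[str, SubModel], lists and single models uniformly. The sketch below makes these assumptions:

- SQLite's JSON type stands in for JSONB.
- PydanticJSON and the enabled/greeting fields are hypothetical names.

Like the other approaches in this thread, it does not track in-place mutations.

```python
from typing import Any

from pydantic import BaseModel, TypeAdapter
from sqlalchemy import JSON, Column, Integer, MetaData, Table, create_engine, insert, select
from sqlalchemy.types import TypeDecorator


class CompanyPersonaSettings(BaseModel):
    enabled: bool = True
    greeting: str = ""


class PydanticJSON(TypeDecorator):
    """JSON column validated by a TypeAdapter, so any annotation works (hypothetical helper)."""

    impl = JSON
    cache_ok = True

    def __init__(self, annotation: Any) -> None:
        super().__init__()
        self.annotation = annotation
        self.adapter = TypeAdapter(annotation)

    def process_bind_param(self, value: Any, dialect: Any) -> Any:
        # Validate first so plain dicts are accepted too, then dump to JSON-compatible data
        if value is None:
            return None
        return self.adapter.dump_python(self.adapter.validate_python(value), mode="json")

    def process_result_value(self, value: Any, dialect: Any) -> Any:
        return None if value is None else self.adapter.validate_python(value)


metadata = MetaData()
companies = Table(
    "companies",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("persona_settings", PydanticJSON(dict[str, CompanyPersonaSettings])),
)
engine = create_engine("sqlite://")
metadata.create_all(engine)

with engine.begin() as conn:
    conn.execute(
        insert(companies).values(
            id=1,
            persona_settings={
                "sales": CompanyPersonaSettings(greeting="Hi"),
                "support": {"enabled": False},  # a plain dict is validated on the way in
            },
        )
    )
    loaded = conn.execute(select(companies.c.persona_settings)).scalar_one()
```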

@Seluj78

Seluj78 commented Mar 4, 2025

Update: I couldn't manage to track mutations on my PydanticJSONB or to make @zhouwubai's solution work for my case...

Sorry

@Seluj78

Seluj78 commented Mar 7, 2025

So actually, this column isn't saved when an assignment or mutation is made to it:

    conversation_transcript: Optional[List[dict]] = Field(sa_column=Column(JSONB, nullable=True), default=None)

i.e.

call.conversation_transcript = new_transcript
# or
call.conversation_transcript.append(entry)

are not flagged as changes, so they are not saved to the DB.

This is a big problem, and it forces me to call flag_modified on it...

@Seluj78

Seluj78 commented Mar 7, 2025

Here is a minimal reproducible example:

import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import JSONB
import sqlmodel as sm
import uuid
import typing as t

class TestTable(sm.SQLModel, table=True):  # type: ignore[call-arg]
    uuid_id: uuid.UUID = sm.Field(  # type: ignore[call-overload]
        default_factory=uuid.uuid4,
        sa_column=sa.Column("id", sa.UUID, unique=True, nullable=False, primary_key=True),
    )
    conversation_transcript: t.Optional[t.List[dict]] = sm.Field(sa_column=sa.Column(JSONB, nullable=True), default=None)


SYNC_SQLALCHEMY_URL = "postgresql://postgres:postgres@localhost:5432/postgres"

SYNC_DB_ENGINE = sa.create_engine(SYNC_SQLALCHEMY_URL, pool_pre_ping=True)  # type: ignore
SYNC_SESSION_MAKER = sa.orm.sessionmaker(bind=SYNC_DB_ENGINE, class_=sa.orm.Session, expire_on_commit=False, autoflush=True)

sm.SQLModel.metadata.create_all(SYNC_DB_ENGINE)

with SYNC_SESSION_MAKER() as session:
    session.add(TestTable(conversation_transcript=[{"text": "Hello"}]))
    session.commit()

with SYNC_SESSION_MAKER() as session:
    items = session.query(TestTable).all()
    item = items[0]
    print(f"Before update: {item.conversation_transcript} (ID: {item.uuid_id})")
    item_id = item.uuid_id
    item.conversation_transcript.append({"text": "World"})
    print(f"After update before save: {item.conversation_transcript} (ID: {item.uuid_id})")
    session.add(item)
    session.commit()

with SYNC_SESSION_MAKER() as session:
    item = session.get(TestTable, item_id)
    print(f"After update after load: {item.conversation_transcript} (ID: {item.uuid_id})")

@fny

fny commented Mar 10, 2025

Following @MaximilianFranz and @Anudorannador, I made it work for lists of nested objects. It works perfectly with JSON, which I need because I want more readable content in the DB.

import json
from typing import Generic, TypeVar

from fastapi.encoders import jsonable_encoder
# Pydantic v2 keeps the model metaclass in a private module
from pydantic._internal._model_construction import ModelMetaclass
from sqlalchemy.types import JSON, TypeDecorator

T = TypeVar('T')

# See issue #63


def pydantic_column_type(pydantic_type):
    class PydanticJSONType(TypeDecorator, Generic[T]):
        impl = JSON()

        def __init__(
            self, json_encoder=json,
        ):
            self.json_encoder = json_encoder
            super(PydanticJSONType, self).__init__()

        def bind_processor(self, dialect):
            # Serialization: validate against the model, then encode as JSON
            impl_processor = self.impl.bind_processor(dialect)
            dumps = self.json_encoder.dumps
            if impl_processor:
                def process(value: T):
                    if value is not None:
                        if isinstance(value, list) and isinstance(pydantic_type, ModelMetaclass):
                            value_to_dump = [pydantic_type.model_validate(item) for item in value]
                        elif isinstance(pydantic_type, ModelMetaclass):
                            value_to_dump = pydantic_type.model_validate(value)
                        else:
                            value_to_dump = value
                        value = jsonable_encoder(value_to_dump)
                    return impl_processor(value)
            else:
                def process(value):
                    if value is None:
                        return None
                    if isinstance(value, list) and isinstance(pydantic_type, ModelMetaclass):
                        value_to_dump = [pydantic_type.model_validate(item) for item in value]
                    elif isinstance(pydantic_type, ModelMetaclass):
                        value_to_dump = pydantic_type.model_validate(value)
                    else:
                        value_to_dump = value
                    value = dumps(jsonable_encoder(value_to_dump))
                    return value
            return process

        def result_processor(self, dialect, coltype):
            # Deserialization: decode JSON, then rebuild the Pydantic model(s)
            impl_processor = self.impl.result_processor(dialect, coltype)
            if impl_processor:
                def process(value):
                    value = impl_processor(value)
                    if value is None:
                        return None
                    data = value
                    # Explicitly use the generic directly, not type(T)
                    if isinstance(data, list) and isinstance(pydantic_type, ModelMetaclass):
                        full_obj = [pydantic_type.model_validate(item) for item in data]
                    elif isinstance(pydantic_type, ModelMetaclass):
                        full_obj = pydantic_type.model_validate(data)
                    else:
                        full_obj = data
                    return full_obj
            else:
                def process(value):
                    if value is None:
                        return None
                    # Explicitly use the generic directly, not type(T)
                    if isinstance(value, list) and isinstance(pydantic_type, ModelMetaclass):
                        full_obj = [pydantic_type.model_validate(item) for item in value]
                    elif isinstance(pydantic_type, ModelMetaclass):
                        full_obj = pydantic_type.model_validate(value)
                    else:
                        full_obj = value
                    return full_obj
            return process

        def compare_values(self, x, y):
            return x == y

    return PydanticJSONType

Example

from typing import List, Optional

from sqlalchemy import Column
from sqlmodel import Field, SQLModel

class J2(SQLModel):
    id: int
    title: str

class Companies(SQLModel, table=True):
    id: Optional[int] = Field(default=None, primary_key=True)
    name: str
    addresses: List['J2'] = Field(sa_column=Column(pydantic_column_type(J2)))

There's no need to set up an engine-level serializer and deserializer, since bind_processor defines the serialization behavior and result_processor the deserialization.

Great work, but it seems to be broken: (1) it runs validations on nested types, which slows down reads, and (2) somehow it triggers other records to be downloaded.
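The serialization/deserialization split described in the quoted comment can also be expressed with TypeDecorator's documented `process_bind_param` and `process_result_value` hooks. These run before and after the wrapped `JSON` type's own processors, so a sketch built on them needs neither `json.dumps` nor the `impl_processor` branching. It is a minimal sketch, not a drop-in replacement: `ModelListJSON` is a hypothetical name, it only handles a list of one model class, it uses a SQLAlchemy Core table on in-memory SQLite for brevity, and, like the quoted code, it still validates on every read.

```python
from typing import Type

import sqlalchemy as sa
from pydantic import BaseModel
from sqlalchemy.types import TypeDecorator


class ModelListJSON(TypeDecorator):
    """Hypothetical helper: a JSON column holding a list of one Pydantic model."""

    impl = sa.JSON
    cache_ok = True

    def __init__(self, model_cls: Type[BaseModel]):
        super().__init__()
        self.model_cls = model_cls

    def process_bind_param(self, value, dialect):
        # Python -> DB: hand JSON-compatible dicts to the wrapped JSON type,
        # which then performs the dialect-specific encoding itself.
        if value is None:
            return None
        return [
            self.model_cls.model_validate(item).model_dump(mode="json")
            for item in value
        ]

    def process_result_value(self, value, dialect):
        # DB -> Python: the JSON type has already decoded the stored text.
        if value is None:
            return None
        return [self.model_cls.model_validate(item) for item in value]


class J2(BaseModel):
    id: int
    title: str


metadata = sa.MetaData()
companies = sa.Table(
    "companies",
    metadata,
    sa.Column("id", sa.Integer, primary_key=True),
    sa.Column("addresses", ModelListJSON(J2)),
)

engine = sa.create_engine("sqlite://")
metadata.create_all(engine)

with engine.begin() as conn:
    # Model instances and plain dicts are both accepted on the way in.
    conn.execute(
        companies.insert(),
        {"id": 1, "addresses": [J2(id=1, title="HQ"), {"id": 2, "title": "Depot"}]},
    )
    loaded = conn.execute(sa.select(companies.c.addresses)).scalar_one()

print(loaded)
```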

@Seluj78

Seluj78 commented Mar 10, 2025

@fny I know, it's certainly not great. If I get a better version, I will be sure to post it here, and you can do the same while we wait for an official implementation.

@fny

fny commented Mar 12, 2025

Hi all, below is a better implementation based on @iloveitaly's ActiveModel project. My version is more robust than ActiveModel's. It also assumes all records coming out of the database are valid (i.e., it uses Model.model_construct(...) instead of Model(...)). Additionally, it can handle deeply nested models and other recursive structures.

This tremendously speeds up loading records since validations are skipped. I anticipate this should handle 80% of use cases.

@tiangolo You might find this useful too.

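import types
from typing import Any, Union, get_args, get_origin

from pydantic import BaseModel
from sqlalchemy.orm import reconstructor


def is_union_type(origin: Any) -> bool:
    # Hypothetical helper (not defined in the original snippet): matches
    # typing.Union/Optional and, on Python 3.10+, "X | Y" unions.
    union_type = getattr(types, "UnionType", None)
    return origin is Union or (union_type is not None and origin is union_type)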

def convert_field_value(annotation: Any, raw_value: Any) -> Any:
    origin = get_origin(annotation)
    args = get_args(annotation)

    if is_union_type(origin):
        if raw_value is None:
            return None
            # The above is optimistic: it's possible that None is not a valid
            # value for the annotation. The code below would account for that.
            #
            # if type(None) in args:
            #     return None
            # else:
            #     raise ValueError(f"None is not a valid value for the annotation {annotation}")
        for arg in args:
            converted = convert_field_value(arg, raw_value)
            if converted is not None:
                return converted
        return None

    if origin is list:
        return [convert_field_value(args[0], item) for item in raw_value]

    if origin is dict:
        return {
            key: convert_field_value(args[1], value) for key, value in raw_value.items()
        }

    if origin is tuple:
        if len(args) != len(raw_value):
            return raw_value
        return tuple(
            convert_field_value(arg, item) for arg, item in zip(args, raw_value)
        )

    try:
        if issubclass(annotation, BaseModel):
            attrs = {
                field_name: convert_field_value(
                    field_info.annotation, raw_value.get(field_name)
                )
                for field_name, field_info in annotation.model_fields.items()
            }

            return annotation.model_construct(**attrs)
    except TypeError as e:
        if "issubclass()" not in str(e):
            raise e

    return raw_value


class SQLModelJSONMixin:
    @reconstructor
    def init_on_load(self):
        for field_name, field_info in self.model_fields.items():
            raw_value = getattr(self, field_name)
            converted = convert_field_value(field_info.annotation, raw_value)
            setattr(self, field_name, converted)

Usage:

class MyModel(SQLModel, SQLModelJSONMixin, table=True):
    id: Optional[int] = Field(default=None, primary_key=True)
    nested_model: Optional[NestedModel] = Field(sa_type=JSON(), nullable=True)

Another note: you need to call record.init_on_load() after you commit the record to the database; otherwise, SQLAlchemy will overwrite the field with a dict-like object.

I'm sure there's something I could add to the mixin, but I haven't had time to investigate.
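The trade-off described in this comment can be seen with plain Pydantic v2 (`Inner` and `Outer` are hypothetical models): `model_construct` skips validation, which is where the speed-up comes from, but it also leaves nested data as plain dicts, which is why `convert_field_value` walks the annotation and constructs each nested model itself.

```python
from pydantic import BaseModel


class Inner(BaseModel):
    value: int


class Outer(BaseModel):
    inner: Inner


raw = {"inner": {"value": 1}}

# model_validate checks every field and builds nested models recursively.
validated = Outer.model_validate(raw)
print(type(validated.inner).__name__)  # Inner

# model_construct trusts the input and skips validation entirely,
# so the nested value stays a plain dict rather than an Inner instance.
constructed = Outer.model_construct(**raw)
print(type(constructed.inner).__name__)  # dict

# Invalid data is accepted silently too, so this only suits trusted rows.
bad = Inner.model_construct(value="not an int")
print(bad.value)  # not an int
```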

@DaanRademaker

DaanRademaker commented Mar 13, 2025

^^ Very nice! Tested the above seems to work great so far!

I also had to add two checks to make it work when None is returned as the default value for an array or list type (which is technically possible):

    if not origin:  # not a container type (e.g. int, untyped list, None, datetime)
        return raw_value
    if raw_value is None:
        return None

@Seluj78

Seluj78 commented Mar 13, 2025

@fny Looks great! Can you provide an example of usage and migrations (with Alembic)?

@fny

fny commented Mar 13, 2025

@DaanRademaker I just realized that error myself. I updated my version to make it more robust and to avoid an issue where setattr(...) was triggering validations. @Seluj78: I added an example. Migrations will work without any additional changes.

@iloveitaly

@fny, I would love to merge these updates into the ActiveModel project if you're up for submitting a PR.

@amanmibra

I am here to bump this. I would love to see this added!

@pporcher

Here is how I do it using pydantic's TypeAdapter.

from sqlalchemy import TypeDecorator
from sqlmodel import JSON
from pydantic import TypeAdapter

class PydanticJson(TypeDecorator):
    impl = JSON()
    cache_ok = True

    def __init__(self, pt):
        super().__init__()
        self.pt = TypeAdapter(pt)
        self.coerce_compared_value = self.impl.coerce_compared_value

    def bind_processor(self, dialect):
        return lambda value: self.pt.dump_json(value) if value is not None else None

    def result_processor(self, dialect, coltype):
        return lambda value: self.pt.validate_json(value) if value is not None else None

And here is how to use it:

from sqlalchemy import Column
from pydantic import BaseModel
from sqlmodel import SQLModel, Field

class Nested(BaseModel):
    value: str

class Parent(SQLModel, table=True):
    id: int = Field(primary_key=True, default=None)
    nested: Nested | None = Field(sa_column=Column(PydanticJson(Nested)))
    nested_list: list[Nested] = Field(sa_column=Column(PydanticJson(list[Nested])))
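A single `TypeAdapter` can back `Nested`, `Nested | None` and `list[Nested]` alike because it validates and serializes arbitrary type expressions, not just `BaseModel` subclasses. A short sketch with plain Pydantic v2 (`Nested` mirrors the model above; the adapter variable names are hypothetical). Note that `dump_json` returns `bytes`, while `dump_python(mode="json")` returns JSON-compatible Python objects; the latter is what you would return from a `process_bind_param` hook if you prefer to let the wrapped `JSON` type do the string encoding.

```python
from typing import Optional

from pydantic import BaseModel, TypeAdapter


class Nested(BaseModel):
    value: str


list_adapter = TypeAdapter(list[Nested])
optional_adapter = TypeAdapter(Optional[Nested])

items = [Nested(value="a"), Nested(value="b")]

# Serialize: compact JSON text as bytes, or JSON-compatible Python objects.
as_bytes = list_adapter.dump_json(items)
as_python = list_adapter.dump_python(items, mode="json")

# Deserialize: from JSON text, or from data a driver has already decoded.
from_bytes = list_adapter.validate_json(as_bytes)
from_python = list_adapter.validate_python(as_python)

print(as_bytes)
print(optional_adapter.validate_python(None))
```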
