diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 3532e81a8e..d377e62d50 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -561,6 +561,7 @@ def get_config(name: str) -> Any: # If it was passed by kwargs, ensure it's also set in config set_config_value(model=new_cls, parameter="table", value=config_table) for k, v in get_model_fields(new_cls).items(): + # convert FieldInfo definitions into sqlalchemy columns col = get_column_from_field(v) setattr(new_cls, k, col) # Set a config flag to tell FastAPI that this should be read with a field @@ -575,6 +576,12 @@ def get_config(name: str) -> Any: # TODO: remove this in the future set_config_value(model=new_cls, parameter="read_with_orm_mode", value=True) + # enables field-level docstrings on the pydanatic `description` field, which we then copy into + # sa_args, which is persisted to sql table comments + set_config_value( + model=new_cls, parameter="use_attribute_docstrings", value=True + ) + config_registry = get_config("registry") if config_registry is not Undefined: config_registry = cast(registry, config_registry) @@ -635,6 +642,7 @@ def __init__( rel_args.extend(rel_info.sa_relationship_args) if rel_info.sa_relationship_kwargs: rel_kwargs.update(rel_info.sa_relationship_kwargs) + # this where RelationshipInfo objects are converted to lazy column evaluators rel_value = relationship(relationship_to, *rel_args, **rel_kwargs) setattr(cls, rel_name, rel_value) # Fix #315 # SQLAlchemy no longer uses dict_ @@ -702,21 +710,32 @@ def get_sqlalchemy_type(field: Any) -> Any: raise ValueError(f"{type_} has no matching SQLAlchemy type") -def get_column_from_field(field: Any) -> Column: # type: ignore +def get_column_from_field(field: PydanticFieldInfo | FieldInfo) -> Column: # type: ignore + """ + Takes a field definition, which can either come from the sqlmodel FieldInfo class or the pydantic variant of that class, + and converts it into a sqlalchemy Column object. + """ if IS_PYDANTIC_V2: field_info = field else: field_info = field.field_info + sa_column = getattr(field_info, "sa_column", Undefined) if isinstance(sa_column, Column): + # if a db field comment is not already defined, and a description exists on the field, add it to the column definition + if not sa_column.comment and field_info.description: + sa_column.comment = field_info.description + return sa_column - sa_type = get_sqlalchemy_type(field) + primary_key = getattr(field_info, "primary_key", Undefined) if primary_key is Undefined: primary_key = False + index = getattr(field_info, "index", Undefined) if index is Undefined: index = False + nullable = not primary_key and is_field_noneable(field) # Override derived nullability if the nullable property is set explicitly # on the field @@ -746,6 +765,7 @@ def get_column_from_field(field: Any) -> Column: # type: ignore "index": index, "unique": unique, } + sa_default = Undefined if field_info.default_factory: sa_default = field_info.default_factory @@ -753,12 +773,29 @@ def get_column_from_field(field: Any) -> Column: # type: ignore sa_default = field_info.default if sa_default is not Undefined: kwargs["default"] = sa_default + sa_column_args = getattr(field_info, "sa_column_args", Undefined) if sa_column_args is not Undefined: args.extend(list(cast(Sequence[Any], sa_column_args))) + sa_column_kwargs = getattr(field_info, "sa_column_kwargs", Undefined) + + if field_info.description: + if sa_column_kwargs is Undefined: + sa_column_kwargs = {} + + assert isinstance(sa_column_kwargs, dict) + + # only update comments if not already set + if "comment" not in sa_column_kwargs: + sa_column_kwargs["comment"] = field_info.description + if sa_column_kwargs is not Undefined: kwargs.update(cast(Dict[Any, Any], sa_column_kwargs)) + + sa_type = get_sqlalchemy_type(field) + + # if sa_column is not specified, then the column is constructed here return Column(sa_type, *args, **kwargs) # type: ignore diff --git a/sqlmodel/sql/sqltypes.py b/sqlmodel/sql/sqltypes.py index 512daacbab..df1243eb0a 100644 --- a/sqlmodel/sql/sqltypes.py +++ b/sqlmodel/sql/sqltypes.py @@ -5,6 +5,12 @@ class AutoString(types.TypeDecorator): # type: ignore + """ + Determines the best sqlalchemy string type based on the database dialect. + + For example, when using Postgres this will return sqlalchemy's String() + """ + impl = types.String cache_ok = True mysql_default_length = 255