diff --git a/devtools/prettier.py b/devtools/prettier.py index c45bc6a..c5eacb7 100644 --- a/devtools/prettier.py +++ b/devtools/prettier.py @@ -15,6 +15,8 @@ try: from sqlalchemy import inspect as sa_inspect + from sqlalchemy.orm import Mapper as sa_Mapper + from sqlalchemy.sql.schema import Table as sa_Table except ImportError: sa_inspect = None # type: ignore[assignment] @@ -86,6 +88,7 @@ def __init__( (DataClassType, self._format_dataclass), (SQLAlchemyClassType, self._format_sqlalchemy_class), ] + self.visited: 'Set[int]' = set() def __call__(self, value: 'Any', *, indent: int = 0, indent_first: bool = False, highlight: bool = False) -> str: self._stream = io.StringIO() @@ -261,7 +264,44 @@ def _format_dataclass(self, value: 'Any', _: str, indent_current: int, indent_ne field_items = ((f, getattr(value, f)) for f in value.__slots__) self._format_fields(value, field_items, indent_current, indent_new) + def _format_sqlalchemy_visited(self, value: 'Any') -> None: + if sa_inspect is None: + self._stream.write(f'""') + return + + inst_state = sa_inspect(value, raiseerr=False) + if inst_state is None or not isinstance(inst_state.mapper, sa_Mapper): + self._stream.write(f'""') + return + + mapper = inst_state.mapper + if isinstance(mapper.persist_selectable, sa_Table): + tablename = mapper.persist_selectable.name + else: + tablename = mapper.class_.__name__ + + unloaded_orm_fields = inst_state.unloaded + fields_list = [] + for c in mapper.columns: + if not c.primary_key or c.name in unloaded_orm_fields: + continue + + try: + _value = getattr(value, c.name) + fields_list.append(f'{c.name}={_value}') + except AttributeError: + pass + + fields = ', '.join(fields_list) + self._stream.write(f'""') + def _format_sqlalchemy_class(self, value: 'Any', _: str, indent_current: int, indent_new: int) -> None: + if id(value) in self.visited: + self._format_sqlalchemy_visited(value) + return + + self.visited.add(id(value)) + if sa_inspect is not None: state = sa_inspect(value) deferred = state.unloaded @@ -271,7 +311,7 @@ def _format_sqlalchemy_class(self, value: 'Any', _: str, indent_current: int, in fields = [ (field, getattr(value, field) if field not in deferred else '') for field in dir(value) - if not (field.startswith('_') or field in ['metadata', 'registry']) + if not (field.startswith('_') or field in ['metadata', 'registry', 'awaitable_attrs']) ] self._format_fields(value, fields, indent_current, indent_new)