Skip to content

Commit 3d372a0

Browse files
author
Victor Naumov
committed
fix 'maximum recursion depth is exceeded' for sqlalchemy objects
1 parent 5022f1f commit 3d372a0

File tree

1 file changed

+41
-1
lines changed

1 file changed

+41
-1
lines changed

devtools/prettier.py

+41-1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515

1616
try:
1717
from sqlalchemy import inspect as sa_inspect
18+
from sqlalchemy.orm import Mapper as sa_Mapper
19+
from sqlalchemy.sql.schema import Table as sa_Table
1820
except ImportError:
1921
sa_inspect = None # type: ignore[assignment]
2022

@@ -86,6 +88,7 @@ def __init__(
8688
(DataClassType, self._format_dataclass),
8789
(SQLAlchemyClassType, self._format_sqlalchemy_class),
8890
]
91+
self.visited: 'Set[int]' = set()
8992

9093
def __call__(self, value: 'Any', *, indent: int = 0, indent_first: bool = False, highlight: bool = False) -> str:
9194
self._stream = io.StringIO()
@@ -261,7 +264,44 @@ def _format_dataclass(self, value: 'Any', _: str, indent_current: int, indent_ne
261264
field_items = ((f, getattr(value, f)) for f in value.__slots__)
262265
self._format_fields(value, field_items, indent_current, indent_new)
263266

267+
def _format_sqlalchemy_visited(self, value: 'Any') -> None:
268+
if sa_inspect is None:
269+
self._stream.write(f'"<visited {value!r})>"')
270+
return
271+
272+
inst_state = sa_inspect(value, raiseerr=False)
273+
if inst_state is None or not isinstance(inst_state.mapper, sa_Mapper):
274+
self._stream.write(f'"<visited {value!r})>"')
275+
return
276+
277+
mapper = inst_state.mapper
278+
if isinstance(mapper.persist_selectable, sa_Table):
279+
tablename = mapper.persist_selectable.name
280+
else:
281+
tablename = mapper.class_.__name__
282+
283+
unloaded_orm_fields = inst_state.unloaded
284+
fields_list = []
285+
for c in mapper.columns:
286+
if not c.primary_key or c.name in unloaded_orm_fields:
287+
continue
288+
289+
try:
290+
_value = getattr(value, c.name)
291+
fields_list.append(f'{c.name}={_value}')
292+
except AttributeError:
293+
pass
294+
295+
fields = ', '.join(fields_list)
296+
self._stream.write(f'"<visited {tablename}({fields})>"')
297+
264298
def _format_sqlalchemy_class(self, value: 'Any', _: str, indent_current: int, indent_new: int) -> None:
299+
if id(value) in self.visited:
300+
self._format_sqlalchemy_visited(value)
301+
return
302+
303+
self.visited.add(id(value))
304+
265305
if sa_inspect is not None:
266306
state = sa_inspect(value)
267307
deferred = state.unloaded
@@ -271,7 +311,7 @@ def _format_sqlalchemy_class(self, value: 'Any', _: str, indent_current: int, in
271311
fields = [
272312
(field, getattr(value, field) if field not in deferred else '<deferred>')
273313
for field in dir(value)
274-
if not (field.startswith('_') or field in ['metadata', 'registry'])
314+
if not (field.startswith('_') or field in ['metadata', 'registry', 'awaitable_attrs'])
275315
]
276316
self._format_fields(value, fields, indent_current, indent_new)
277317

0 commit comments

Comments
 (0)