15
15
16
16
try :
17
17
from sqlalchemy import inspect as sa_inspect
18
+ from sqlalchemy .orm import Mapper as sa_Mapper
19
+ from sqlalchemy .sql .schema import Table as sa_Table
18
20
except ImportError :
19
21
sa_inspect = None # type: ignore[assignment]
20
22
@@ -86,6 +88,7 @@ def __init__(
86
88
(DataClassType , self ._format_dataclass ),
87
89
(SQLAlchemyClassType , self ._format_sqlalchemy_class ),
88
90
]
91
+ self .visited : 'Set[int]' = set ()
89
92
90
93
def __call__ (self , value : 'Any' , * , indent : int = 0 , indent_first : bool = False , highlight : bool = False ) -> str :
91
94
self ._stream = io .StringIO ()
@@ -261,7 +264,44 @@ def _format_dataclass(self, value: 'Any', _: str, indent_current: int, indent_ne
261
264
field_items = ((f , getattr (value , f )) for f in value .__slots__ )
262
265
self ._format_fields (value , field_items , indent_current , indent_new )
263
266
267
+ def _format_sqlalchemy_visited (self , value : 'Any' ) -> None :
268
+ if sa_inspect is None :
269
+ self ._stream .write (f'"<visited { value !r} )>"' )
270
+ return
271
+
272
+ inst_state = sa_inspect (value , raiseerr = False )
273
+ if inst_state is None or not isinstance (inst_state .mapper , sa_Mapper ):
274
+ self ._stream .write (f'"<visited { value !r} )>"' )
275
+ return
276
+
277
+ mapper = inst_state .mapper
278
+ if isinstance (mapper .persist_selectable , sa_Table ):
279
+ tablename = mapper .persist_selectable .name
280
+ else :
281
+ tablename = mapper .class_ .__name__
282
+
283
+ unloaded_orm_fields = inst_state .unloaded
284
+ fields_list = []
285
+ for c in mapper .columns :
286
+ if not c .primary_key or c .name in unloaded_orm_fields :
287
+ continue
288
+
289
+ try :
290
+ _value = getattr (value , c .name )
291
+ fields_list .append (f'{ c .name } ={ _value } ' )
292
+ except AttributeError :
293
+ pass
294
+
295
+ fields = ', ' .join (fields_list )
296
+ self ._stream .write (f'"<visited { tablename } ({ fields } )>"' )
297
+
264
298
def _format_sqlalchemy_class (self , value : 'Any' , _ : str , indent_current : int , indent_new : int ) -> None :
299
+ if id (value ) in self .visited :
300
+ self ._format_sqlalchemy_visited (value )
301
+ return
302
+
303
+ self .visited .add (id (value ))
304
+
265
305
if sa_inspect is not None :
266
306
state = sa_inspect (value )
267
307
deferred = state .unloaded
@@ -271,7 +311,7 @@ def _format_sqlalchemy_class(self, value: 'Any', _: str, indent_current: int, in
271
311
fields = [
272
312
(field , getattr (value , field ) if field not in deferred else '<deferred>' )
273
313
for field in dir (value )
274
- if not (field .startswith ('_' ) or field in ['metadata' , 'registry' ])
314
+ if not (field .startswith ('_' ) or field in ['metadata' , 'registry' , 'awaitable_attrs' ])
275
315
]
276
316
self ._format_fields (value , fields , indent_current , indent_new )
277
317
0 commit comments