forked from samuelcolvin/python-devtools
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathprettier.py
354 lines (300 loc) · 13.3 KB
/
prettier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
import ast
import io
import os
from collections import OrderedDict
from collections.abc import Generator
from .utils import DataClassType, LaxMapping, SQLAlchemyClassType, env_true, isatty
try:
    from functools import cache
except ImportError:
    # functools.cache only exists from Python 3.9; it is defined as lru_cache(maxsize=None), so mirror
    # that here (a bare lru_cache() would silently cap the cache at 128 entries)
    from functools import lru_cache

    cache = lru_cache(maxsize=None)

try:
    from sqlalchemy import inspect as sa_inspect
    from sqlalchemy.orm import Mapper as sa_Mapper
    from sqlalchemy.sql.schema import Table as sa_Table
except ImportError:
    # sqlalchemy is optional: the formatter checks ``sa_inspect is None`` before using any of these names
    sa_inspect = None  # type: ignore[assignment]
__all__ = ('PrettyFormat', 'pformat', 'pprint')

# The typing names below are only needed by type checkers: MYPY is False at runtime (mypy special-cases
# a constant named MYPY as True), and every annotation in this module is a string, so nothing here
# is imported when the module actually runs.
MYPY = False
if MYPY:
    from typing import Any, Callable, Iterable, List, Set, Tuple, Union
# (type, opening, closing) delimiters for list-like containers; a value matching none of these
# (e.g. a tuple) is wrapped in plain parentheses by the list-like formatter.
PARENTHESES_LOOKUP = [(list, '[', ']'), (set, '{', '}'), (frozenset, 'frozenset({', '})')]

# Width read from the PY_DEVTOOLS_WIDTH environment variable, 120 when unset.
DEFAULT_WIDTH = int(os.environ.get('PY_DEVTOOLS_WIDTH', 120))

# Sentinel meaning "no value", distinct from any real value including None.
MISSING = object()

# Key of the single-entry dict built by fmt() to mark a value for recursive pretty-formatting.
PRETTY_KEY = '__prettier_formatted_value__'
def fmt(v: 'Any') -> 'Any':
    """Wrap ``v`` so that a ``__pretty__`` generator can yield it and have it formatted recursively."""
    wrapped = {PRETTY_KEY: v}
    return wrapped
class SkipPretty(Exception):
    """Raised from a ``__pretty__`` method to make the formatter fall back to its normal handling."""
@cache
def get_pygments() -> 'Tuple[Any, Any, Any]':
    """
    Return ``(pygments module, PythonLexer instance, Terminal256Formatter instance)``.

    All three are ``None`` when pygments is not installed. The result is cached, so the import and
    the lexer/formatter construction happen at most once.
    """
    try:
        import pygments
        from pygments.formatters import Terminal256Formatter
        from pygments.lexers import PythonLexer
    except ImportError:  # pragma: no cover
        return None, None, None

    lexer = PythonLexer()
    formatter = Terminal256Formatter(style='vim')
    return pygments, lexer, formatter
# Common generator-like types. Not exhaustive: things like ``chain`` are not included to avoid the import.
generator_types = (Generator, map, filter, zip, enumerate)
class PrettyFormat:
    """
    Callable pretty-printer that renders arbitrary Python objects as indented, wrapped text.

    ``PrettyFormat(...)(value)`` returns the formatted string. An object can customise its output with a
    bound ``__pretty__(fmt, skip_exc)`` method. Otherwise formatting dispatches on the value's type
    through ``self._type_lookup`` and falls back to a (possibly wrapped) ``repr``.
    """

    def __init__(
        self,
        indent_step: int = 4,
        indent_char: str = ' ',
        repr_strings: bool = False,
        simple_cutoff: int = 10,
        width: int = 120,
        yield_from_generators: bool = True,
    ):
        """
        :param indent_step: number of ``indent_char`` added per nesting level
        :param indent_char: character repeated to build indentation
        :param repr_strings: if True, str/bytes values are always written as a single ``repr``
        :param simple_cutoff: values whose ``repr`` is at most this long are written verbatim
            (generators excepted)
        :param width: line width targeted when splitting long strings and reprs
        :param yield_from_generators: if True, generators are consumed and their items printed;
            if False, only their ``repr`` is written
        """
        self._indent_step = indent_step
        self._c = indent_char
        self._repr_strings = repr_strings
        self._repr_generators = not yield_from_generators
        self._simple_cutoff = simple_cutoff
        self._width = width
        # (type or tuple of types, formatter) pairs; the first isinstance() match wins
        self._type_lookup: 'List[Tuple[Any, Callable[[Any, str, int, int], None]]]' = [
            (dict, self._format_dict),
            ((str, bytes), self._format_str_bytes),
            (tuple, self._format_tuples),
            ((list, set, frozenset), self._format_list_like),
            (bytearray, self._format_bytearray),
            (generator_types, self._format_generator),
            # put these last as the check can be slow
            (ast.AST, self._format_ast_expression),
            (LaxMapping, self._format_dict),
            (DataClassType, self._format_dataclass),
            (SQLAlchemyClassType, self._format_sqlalchemy_class),
        ]
        # ids of SQLAlchemy instances already rendered during the current call; an instance reached
        # again is written as a short "<visited ...>" placeholder instead. Cleared by __call__.
        self.visited: 'Set[int]' = set()

    def __call__(self, value: 'Any', *, indent: int = 0, indent_first: bool = False, highlight: bool = False) -> str:
        """
        Format ``value`` and return the result.

        :param value: object to format
        :param indent: starting indentation, counted in ``indent_char`` units
        :param indent_first: if True, the starting indentation is also written before the first line
        :param highlight: if True and pygments is installed, colourise the output for a terminal
        :return: the formatted text
        """
        self._stream = io.StringIO()
        # visited state must not leak between calls: otherwise an object printed by an earlier call is
        # shown as "<visited ...>", and ids of since-freed objects can collide with new, unrelated ones
        self.visited.clear()
        self._format(value, indent_current=indent, indent_first=indent_first)
        s = self._stream.getvalue()
        if highlight:
            # only load pygments when highlighting is actually requested
            pygments, pyg_lexer, pyg_formatter = get_pygments()
            if pygments:
                # apparently highlight adds a trailing new line we don't want
                s = pygments.highlight(s, lexer=pyg_lexer, formatter=pyg_formatter).rstrip('\n')
        return s

    def _format(self, value: 'Any', indent_current: int, indent_first: bool) -> None:
        """
        Write ``value`` to the output stream.

        The order of attempts is:
        1. the value's own bound ``__pretty__`` method;
        2. a verbatim ``repr`` when it is short;
        3. the first matching formatter in ``self._type_lookup``;
        4. ``_format_raw``.
        """
        if indent_first:
            self._stream.write(indent_current * self._c)

        try:
            pretty_func = getattr(value, '__pretty__')
        except AttributeError:
            pass
        else:
            # `pretty_func.__class__.__name__ == 'method'` should only be true for bound methods,
            # `hasattr(pretty_func, '__self__')` is more canonical but weirdly is true for unbound cython functions
            from unittest.mock import _Call as MockCall

            if pretty_func.__class__.__name__ == 'method' and not isinstance(value, MockCall):
                try:
                    gen = pretty_func(fmt=fmt, skip_exc=SkipPretty)
                    # NOTE: anything written before SkipPretty is raised stays in the stream
                    self._render_pretty(gen, indent_current)
                except SkipPretty:
                    pass
                else:
                    return None

        value_repr = repr(value)
        # generators always go to their dedicated formatter, whatever the cutoff
        if len(value_repr) <= self._simple_cutoff and not isinstance(value, generator_types):
            self._stream.write(value_repr)
        else:
            indent_new = indent_current + self._indent_step
            for t, func in self._type_lookup:
                if isinstance(value, t):
                    func(value, value_repr, indent_current, indent_new)
                    return None

            self._format_raw(value, value_repr, indent_current, indent_new)

    def _render_pretty(self, gen: 'Iterable[Any]', indent: int) -> None:
        """
        Write the items yielded by a ``__pretty__`` method.

        The integers -1, 0 and 1 change the indent by that many ``indent_step`` and request a new line
        before the next item. ``fmt(...)`` wrapped values are formatted recursively, strings are
        written verbatim, and anything else is written as its ``repr``.
        """
        prefix = False
        for v in gen:
            if isinstance(v, int) and v in {-1, 0, 1}:
                indent += v * self._indent_step
                prefix = True
            else:
                if prefix:
                    self._stream.write('\n' + self._c * indent)
                    prefix = False

                pretty_value = v.get(PRETTY_KEY, MISSING) if (isinstance(v, dict) and len(v) == 1) else MISSING
                if pretty_value is not MISSING:
                    self._format(pretty_value, indent, False)
                elif isinstance(v, str):
                    self._stream.write(v)
                else:
                    # shouldn't happen but will
                    self._stream.write(repr(v))

    def _format_dict(self, value: 'Any', _: str, indent_current: int, indent_new: int) -> None:
        """
        Format a mapping with one ``key: value`` entry per line.

        Plain dicts are written as ``{...}``, OrderedDicts as ``OrderedDict([(k, v), ...])``, and any
        other mapping type as ``<ClassName({...})>``.
        """
        open_, before_, split_, after_, close_ = '{\n', indent_new * self._c, ': ', ',\n', '}'
        if isinstance(value, OrderedDict):
            open_, split_, after_, close_ = 'OrderedDict([\n', ', ', '),\n', '])'
            before_ += '('
        elif type(value) is not dict:
            open_, close_ = f'<{value.__class__.__name__}({{\n', '})>'

        self._stream.write(open_)
        for k, v in value.items():
            self._stream.write(before_)
            self._format(k, indent_new, False)
            self._stream.write(split_)
            self._format(v, indent_new, False)
            self._stream.write(after_)
        self._stream.write(indent_current * self._c + close_)

    def _format_list_like(
        self, value: 'Union[List[Any], Tuple[Any, ...], Set[Any]]', _: str, indent_current: int, indent_new: int
    ) -> None:
        """Format a list/set/frozenset/tuple with one item per line, using the delimiters from PARENTHESES_LOOKUP."""
        open_, close_ = '(', ')'
        for t, *oc in PARENTHESES_LOOKUP:
            if isinstance(value, t):
                open_, close_ = oc
                break

        self._stream.write(open_ + '\n')
        for v in value:
            self._format(v, indent_new, True)
            self._stream.write(',\n')
        self._stream.write(indent_current * self._c + close_)

    def _format_tuples(self, value: 'Tuple[Any, ...]', value_repr: str, indent_current: int, indent_new: int) -> None:
        """Format named tuples as ``Name(field=value, ...)`` and other tuples like lists."""
        fields = getattr(value, '_fields', None)
        if fields:
            # named tuple
            self._format_fields(value, zip(fields, value), indent_current, indent_new)
        else:
            # normal tuples are just like other similar iterables
            self._format_list_like(value, value_repr, indent_current, indent_new)

    def _format_str_bytes(
        self, value: 'Union[str, bytes]', value_repr: str, indent_current: int, indent_new: int
    ) -> None:
        """Write str/bytes as a single repr, or as a parenthesised run of repr'd chunks when it spans several lines."""
        if self._repr_strings:
            self._stream.write(value_repr)
        else:
            lines = list(self._wrap_lines(value, indent_new))
            if len(lines) > 1:
                self._str_lines(lines, indent_current, indent_new)
            else:
                self._stream.write(value_repr)

    def _str_lines(self, lines: 'Iterable[Union[str, bytes]]', indent_current: int, indent_new: int) -> None:
        """Write ``lines`` as ``(\\n <repr(line)>\\n ...)``, i.e. implicitly concatenated literals in parentheses."""
        self._stream.write('(\n')
        prefix = indent_new * self._c
        for line in lines:
            self._stream.write(prefix + repr(line) + '\n')
        self._stream.write(indent_current * self._c + ')')

    def _wrap_lines(self, s: 'Union[str, bytes]', indent_new: int) -> 'Generator[Union[str, bytes], None, None]':
        """
        Yield ``s`` split into its lines (line endings kept), each cut into chunks of at most
        ``width - indent_new - 3`` characters.

        When that chunk size is not positive (very deep nesting), lines are yielded unsplit.
        """
        # the 3 presumably leaves room for repr()'s quotes and the b'' prefix
        width = self._width - indent_new - 3
        lines = s.splitlines(True)
        if width <= 0:
            # a zero step would make range() raise ValueError; a negative one never splits anyway
            yield from lines
            return
        for line in lines:
            start = 0
            for pos in range(width, len(line), width):
                yield line[start:pos]
                start = pos
            yield line[start:]

    def _format_generator(
        self, value: 'Generator[Any, None, None]', value_repr: str, indent_current: int, indent_new: int
    ) -> None:
        """Consume a generator and write its items one per line, or just its repr when yield_from_generators=False."""
        if self._repr_generators:
            self._stream.write(value_repr)
        else:
            name = value.__class__.__name__
            if name == 'generator':
                # no name if the name is just "generator"
                self._stream.write('(\n')
            else:
                self._stream.write(f'{name}(\n')
            for v in value:
                self._format(v, indent_new, True)
                self._stream.write(',\n')
            self._stream.write(indent_current * self._c + ')')

    def _format_bytearray(self, value: 'Any', _: str, indent_current: int, indent_new: int) -> None:
        """Write a bytearray as ``bytearray(`` followed by its content as wrapped bytes literals."""
        self._stream.write('bytearray')
        lines = self._wrap_lines(bytes(value), indent_new)
        self._str_lines(lines, indent_current, indent_new)

    def _format_ast_expression(self, value: ast.AST, _: str, indent_current: int, indent_new: int) -> None:
        """Write ``ast.dump`` of an AST node, indented relative to the current position where supported."""
        try:
            s = ast.dump(value, indent=self._indent_step)
        except TypeError:
            # no indent before 3.9
            s = ast.dump(value)
        lines = s.splitlines(True)
        self._stream.write(lines[0])
        for line in lines[1:]:
            self._stream.write(indent_current * self._c + line)

    def _format_dataclass(self, value: 'Any', _: str, indent_current: int, indent_new: int) -> None:
        """Write a dataclass instance as ``Name(attr=value, ...)``, using ``__slots__`` when it has no ``__dict__``."""
        try:
            field_items = value.__dict__.items()
        except AttributeError:
            # slots
            field_items = ((f, getattr(value, f)) for f in value.__slots__)
        self._format_fields(value, field_items, indent_current, indent_new)

    def _format_sqlalchemy_visited(self, value: 'Any') -> None:
        """
        Write a quoted placeholder for a SQLAlchemy instance that was already rendered in this call.

        When the instance's mapper can be inspected, the placeholder names the table (or mapped class)
        and the loaded primary-key values, e.g. ``"<visited users(id=1)>"``. Otherwise it embeds ``repr(value)``.
        """
        if sa_inspect is None:
            self._stream.write(f'"<visited {value!r}>"')
            return

        inst_state = sa_inspect(value, raiseerr=False)
        if inst_state is None or not isinstance(inst_state.mapper, sa_Mapper):
            self._stream.write(f'"<visited {value!r}>"')
            return

        mapper = inst_state.mapper
        if isinstance(mapper.persist_selectable, sa_Table):
            tablename = mapper.persist_selectable.name
        else:
            tablename = mapper.class_.__name__

        unloaded_orm_fields = inst_state.unloaded
        fields_list = []
        for c in mapper.columns:
            # NOTE(review): c.name is the DB column name, which can differ from the mapped attribute key
            # used by getattr() and by `unloaded`; confirm whether the attribute key should be used here
            if not c.primary_key or c.name in unloaded_orm_fields:
                continue
            try:
                _value = getattr(value, c.name)
                fields_list.append(f'{c.name}={_value}')
            except AttributeError:
                pass
        fields = ', '.join(fields_list)
        self._stream.write(f'"<visited {tablename}({fields})>"')

    def _format_sqlalchemy_class(self, value: 'Any', _: str, indent_current: int, indent_new: int) -> None:
        """
        Write a SQLAlchemy model instance as ``Name(attr=value, ...)``.

        Public attributes from ``dir(value)`` are listed, except ``metadata``, ``registry`` and
        ``awaitable_attrs``. Unloaded attributes are shown as ``'<deferred>'`` instead of being fetched.
        An instance already rendered in this call is written as a placeholder.
        """
        if id(value) in self.visited:
            self._format_sqlalchemy_visited(value)
            return
        self.visited.add(id(value))

        if sa_inspect is not None:
            state = sa_inspect(value)
            deferred = state.unloaded
        else:
            deferred = set()

        fields = [
            (field, getattr(value, field) if field not in deferred else '<deferred>')
            for field in dir(value)
            if not (field.startswith('_') or field in ['metadata', 'registry', 'awaitable_attrs'])
        ]
        self._format_fields(value, fields, indent_current, indent_new)

    def _format_raw(self, _: 'Any', value_repr: str, indent_current: int, indent_new: int) -> None:
        """
        Write an arbitrary ``repr``.

        Multi-line or over-wide reprs are put inside parentheses and word-wrapped to the configured
        width. Short reprs are written as-is.
        """
        lines = value_repr.splitlines(True)
        if len(lines) > 1 or (len(value_repr) + indent_current) >= self._width:
            from textwrap import wrap

            self._stream.write('(\n')
            wrap_at = self._width - indent_new
            prefix = indent_new * self._c
            for line in lines:
                if wrap_at > 0:
                    sub_lines = wrap(line, wrap_at)
                else:
                    # textwrap.wrap() raises ValueError for widths < 1, which happens once indent_new reaches
                    # the width; emit the line unwrapped (trailing whitespace and blank lines dropped, as wrap does)
                    stripped = line.rstrip()
                    sub_lines = [stripped] if stripped else []
                for sline in sub_lines:
                    self._stream.write(prefix + sline + '\n')
            self._stream.write(indent_current * self._c + ')')
        else:
            self._stream.write(value_repr)

    def _format_fields(
        self, value: 'Any', fields: 'Iterable[Tuple[str, Any]]', indent_current: int, indent_new: int
    ) -> None:
        """Write ``ClassName(`` then one ``name=value,`` line per field (name omitted when falsy), then ``)``."""
        self._stream.write(f'{value.__class__.__name__}(\n')
        for field, v in fields:
            self._stream.write(indent_new * self._c)
            if field:  # field is falsy sometimes for odd things like call_args
                self._stream.write(f'{field}=')
            self._format(v, indent_new, False)
            self._stream.write(',\n')
        self._stream.write(indent_current * self._c + ')')
# Highlight override read from PY_DEVTOOLS_HIGHLIGHT; when it is None, pprint() decides via isatty()
# (presumably None means "unset" — see env_true).
force_highlight = env_true('PY_DEVTOOLS_HIGHLIGHT', None)
# Shared formatter instance with default settings, used by pprint().
pformat = PrettyFormat()
def pprint(s: 'Any', file: 'Any' = None) -> None:
    """
    Pretty-format ``s`` with the shared ``pformat`` instance and print it to ``file``.

    ``file=None`` means ``print``'s default stream. Highlighting follows ``force_highlight`` when it
    is not None; otherwise it is decided by ``isatty(file)``. Output is flushed immediately.
    """
    if force_highlight is not None:
        highlight = force_highlight
    else:
        highlight = isatty(file)
    rendered = pformat(s, highlight=highlight)
    print(rendered, file=file, flush=True)