Skip to content

Commit d80bf00

Browse files
committed
[stubgenc] Recognize pybind11 static properties
1 parent 5538a68 commit d80bf00

File tree

3 files changed

+52
-48
lines changed

3 files changed

+52
-48
lines changed

mypy/stubgenc.py

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
_DEFAULT_TYPING_IMPORTS = (
2323
'Any',
2424
'Callable',
25+
'ClassVar',
2526
'Dict',
2627
'Iterable',
2728
'Iterator',
@@ -243,7 +244,13 @@ def strip_or_import(typ: str, module: ModuleType, imports: List[str]) -> str:
243244
return stripped_type
244245

245246

246-
def generate_c_property_stub(name: str, obj: object, output: List[str], readonly: bool,
247+
def is_static_property(obj: object) -> bool:
248+
return type(obj).__name__ == 'pybind11_static_property'
249+
250+
251+
def generate_c_property_stub(name: str, obj: object,
252+
static_properties: List[str],
253+
properties: List[str], readonly: bool,
247254
module: Optional[ModuleType] = None,
248255
imports: Optional[List[str]] = None) -> None:
249256
"""Generate property stub using introspection of 'obj'.
@@ -273,11 +280,17 @@ def infer_prop_type(docstr: Optional[str]) -> Optional[str]:
273280
if module is not None and imports is not None:
274281
inferred = strip_or_import(inferred, module, imports)
275282

276-
output.append('@property')
277-
output.append('def {}(self) -> {}: ...'.format(name, inferred))
278-
if not readonly:
279-
output.append('@{}.setter'.format(name))
280-
output.append('def {}(self, val: {}) -> None: ...'.format(name, inferred))
283+
if is_static_property(obj):
284+
trailing_comment = " # read-only" if readonly else ""
285+
static_properties.append(
286+
'{}: ClassVar[{}] = ...{}'.format(name, inferred, trailing_comment)
287+
)
288+
else: # regular property
289+
properties.append('@property')
290+
properties.append('def {}(self) -> {}: ...'.format(name, inferred))
291+
if not readonly:
292+
properties.append('@{}.setter'.format(name))
293+
properties.append('def {}(self, val: {}) -> None: ...'.format(name, inferred))
281294

282295

283296
def generate_c_type_stub(module: ModuleType,
@@ -298,6 +311,7 @@ def generate_c_type_stub(module: ModuleType,
298311
items = sorted(obj_dict.items(), key=lambda x: method_name_sort_key(x[0]))
299312
methods = [] # type: List[str]
300313
types = [] # type: List[str]
314+
static_properties = [] # type: List[str]
301315
properties = [] # type: List[str]
302316
done = set() # type: Set[str]
303317
for attr, value in items:
@@ -322,19 +336,19 @@ def generate_c_type_stub(module: ModuleType,
322336
class_sigs=class_sigs)
323337
elif is_c_property(value):
324338
done.add(attr)
325-
generate_c_property_stub(attr, value, properties, is_c_property_readonly(value),
339+
generate_c_property_stub(attr, value, static_properties, properties,
340+
is_c_property_readonly(value),
326341
module=module, imports=imports)
327342
elif is_c_type(value):
328343
generate_c_type_stub(module, attr, value, types, imports=imports, sigs=sigs,
329344
class_sigs=class_sigs)
330345
done.add(attr)
331346

332-
variables = []
333347
for attr, value in items:
334348
if is_skipped_attribute(attr):
335349
continue
336350
if attr not in done:
337-
variables.append('%s: %s = ...' % (
351+
static_properties.append('%s: ClassVar[%s] = ...' % (
338352
attr, strip_or_import(get_type_fullname(type(value)), module, imports)))
339353
all_bases = obj.mro()
340354
if all_bases[-1] is object:
@@ -361,21 +375,21 @@ def generate_c_type_stub(module: ModuleType,
361375
)
362376
else:
363377
bases_str = ''
364-
if not methods and not variables and not properties and not types:
365-
output.append('class %s%s: ...' % (class_name, bases_str))
366-
else:
378+
if types or static_properties or methods or properties:
367379
output.append('class %s%s:' % (class_name, bases_str))
368380
for line in types:
369381
if output and output[-1] and \
370382
not output[-1].startswith('class') and line.startswith('class'):
371383
output.append('')
372384
output.append(' ' + line)
373-
for variable in variables:
374-
output.append(' %s' % variable)
375-
for method in methods:
376-
output.append(' %s' % method)
377-
for prop in properties:
378-
output.append(' %s' % prop)
385+
for line in static_properties:
386+
output.append(' %s' % line)
387+
for line in methods:
388+
output.append(' %s' % line)
389+
for line in properties:
390+
output.append(' %s' % line)
391+
else:
392+
output.append('class %s%s: ...' % (class_name, bases_str))
379393

380394

381395
def get_type_fullname(typ: type) -> str:

mypy/test/teststubgen.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -676,7 +676,7 @@ class TestClassVariableCls:
676676
mod = ModuleType('module', '') # any module is fine
677677
generate_c_type_stub(mod, 'C', TestClassVariableCls, output, imports)
678678
assert_equal(imports, [])
679-
assert_equal(output, ['class C:', ' x: int = ...'])
679+
assert_equal(output, ['class C:', ' x: ClassVar[int] = ...'])
680680

681681
def test_generate_c_type_inheritance(self) -> None:
682682
class TestClass(KeyError):
@@ -815,7 +815,7 @@ def get_attribute(self) -> None:
815815
attribute = property(get_attribute, doc="")
816816

817817
output = [] # type: List[str]
818-
generate_c_property_stub('attribute', TestClass.attribute, output, readonly=True)
818+
generate_c_property_stub('attribute', TestClass.attribute, [], output, readonly=True)
819819
assert_equal(output, ['@property', 'def attribute(self) -> str: ...'])
820820

821821
def test_generate_c_type_with_single_arg_generic(self) -> None:

test-data/stubgen/pybind11_mypy_demo/basics.pyi

Lines changed: 18 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
1+
from typing import ClassVar
2+
13
from typing import overload
24
PI: float
35

46
class Point:
57
class AngleUnit:
6-
__entries: dict = ...
7-
degree: Point.AngleUnit = ...
8-
radian: Point.AngleUnit = ...
8+
__doc__: ClassVar[str] = ... # read-only
9+
__members__: ClassVar[dict] = ... # read-only
10+
__entries: ClassVar[dict] = ...
11+
degree: ClassVar[Point.AngleUnit] = ...
12+
radian: ClassVar[Point.AngleUnit] = ...
913
def __init__(self, value: int) -> None: ...
1014
def __eq__(self, other: object) -> bool: ...
1115
def __getstate__(self) -> int: ...
@@ -16,16 +20,14 @@ class Point:
1620
def __setstate__(self, state: int) -> None: ...
1721
@property
1822
def name(self) -> str: ...
19-
@property
20-
def __doc__(self) -> str: ...
21-
@property
22-
def __members__(self) -> dict: ...
2323

2424
class LengthUnit:
25-
__entries: dict = ...
26-
inch: Point.LengthUnit = ...
27-
mm: Point.LengthUnit = ...
28-
pixel: Point.LengthUnit = ...
25+
__doc__: ClassVar[str] = ... # read-only
26+
__members__: ClassVar[dict] = ... # read-only
27+
__entries: ClassVar[dict] = ...
28+
inch: ClassVar[Point.LengthUnit] = ...
29+
mm: ClassVar[Point.LengthUnit] = ...
30+
pixel: ClassVar[Point.LengthUnit] = ...
2931
def __init__(self, value: int) -> None: ...
3032
def __eq__(self, other: object) -> bool: ...
3133
def __getstate__(self) -> int: ...
@@ -36,11 +38,11 @@ class Point:
3638
def __setstate__(self, state: int) -> None: ...
3739
@property
3840
def name(self) -> str: ...
39-
@property
40-
def __doc__(self) -> str: ...
41-
@property
42-
def __members__(self) -> dict: ...
43-
origin: Point = ...
41+
angle_unit: ClassVar[Point.AngleUnit] = ...
42+
length_unit: ClassVar[Point.LengthUnit] = ...
43+
x_axis: ClassVar[Point] = ... # read-only
44+
y_axis: ClassVar[Point] = ... # read-only
45+
origin: ClassVar[Point] = ...
4446
@overload
4547
def __init__(self) -> None: ...
4648
@overload
@@ -50,27 +52,15 @@ class Point:
5052
@overload
5153
def distance_to(self, other: Point) -> float: ...
5254
@property
53-
def angle_unit(self) -> Point.AngleUnit: ...
54-
@angle_unit.setter
55-
def angle_unit(self, val: Point.AngleUnit) -> None: ...
56-
@property
5755
def length(self) -> float: ...
5856
@property
59-
def length_unit(self) -> Point.LengthUnit: ...
60-
@length_unit.setter
61-
def length_unit(self, val: Point.LengthUnit) -> None: ...
62-
@property
6357
def x(self) -> float: ...
6458
@x.setter
6559
def x(self, val: float) -> None: ...
6660
@property
67-
def x_axis(self) -> Point: ...
68-
@property
6961
def y(self) -> float: ...
7062
@y.setter
7163
def y(self, val: float) -> None: ...
72-
@property
73-
def y_axis(self) -> Point: ...
7464

7565
def answer() -> int: ...
7666
def midpoint(left: float, right: float) -> float: ...

0 commit comments

Comments
 (0)